VyLala committed
Commit 8835144 · verified · 1 Parent(s): f3b0e2e

Upload 52 files

update 28_7_2025

.gitattributes CHANGED
@@ -1,37 +1,37 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- credentials.json filter=crypt diff=crypt
37
-
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ credentials.json filter=crypt diff=crypt
37
+
NER/PDF/__pycache__/pdf.cpython-311.pyc CHANGED
Binary files a/NER/PDF/__pycache__/pdf.cpython-311.pyc and b/NER/PDF/__pycache__/pdf.cpython-311.pyc differ
 
NER/PDF/pdf.py CHANGED
@@ -6,86 +6,128 @@ from bs4 import BeautifulSoup
6
  import requests
7
  from NER import cleanText
8
  #!pip install tabula-py
9
-
10
  import tabula
11
- class PDF(): # using PyPDF2
12
  def __init__(self, pdf, saveFolder, doi=None):
13
  self.pdf = pdf
14
  self.doi = doi
15
  self.saveFolder = saveFolder
 
16
  def openPDFFile(self):
17
  if "https" in self.pdf:
18
  name = self.pdf.split("/")[-1]
19
  name = self.downloadPDF(self.saveFolder)
20
  if name != "no pdfLink to download":
21
- fileToOpen = self.saveFolder + "/" + name
22
- else: fileToOpen = self.pdf
23
- else: fileToOpen = self.pdf
 
 
24
  return open(fileToOpen, "rb")
 
25
  def downloadPDF(self, saveFolder):
26
  pdfLink = ''
27
- if ".pdf" not in self.pdf and "https" not in self.pdf: # the download link is a general URL not pdf link
28
  r = requests.get(self.pdf)
29
  soup = BeautifulSoup(r.content, 'html.parser')
30
  links = soup.find_all("a")
31
  for link in links:
32
- if ".pdf" in link.get("href"):
33
  if self.doi in link.get("href"):
34
  pdfLink = link.get("href")
35
  break
36
  else:
37
  pdfLink = self.pdf
 
38
  if pdfLink != '':
39
  response = requests.get(pdfLink)
40
  name = pdfLink.split("/")[-1]
41
- pdf = open(saveFolder+"/"+name, 'wb')
42
- pdf.write(response.content)
43
- pdf.close()
 
 
44
  print("pdf downloaded")
45
  return name
46
  else:
47
  return "no pdfLink to download"
 
48
  def extractText(self):
 
49
  jsonPage = {}
50
- pdf = self.openPDFFile()
51
- doc = PDFDocument(pdf)
52
- viewer = SimplePDFViewer(pdf)
53
- all_pages = [p for p in doc.pages()]
54
- cl = cleanText.cleanGenText()
55
- for page in range(1,len(all_pages)):
56
- viewer.navigate(page)
57
- viewer.render()
58
- if str(page) not in jsonPage:
59
- jsonPage[str(page)] = {}
60
- # text
61
- text = "".join(viewer.canvas.strings)
62
- clean, filteredWord = cl.textPreprocessing(text) #cleanText.cleanGenText(text).cleanText()
63
- # save the text of filtered words which remove "a", the, "an", "is", etc.
64
- jsonPage[str(page)]["normalText"] = [text]
65
- jsonPage[str(page)]["cleanText"] = [' '.join(filteredWord)]
66
- #image
67
- image = viewer.canvas.images
68
- jsonPage[str(page)]["image"] = [image]
69
- #form
70
- form = viewer.canvas.forms
71
- jsonPage[str(page)]["form"] = [form]
72
- # content based on PDF adobe
73
- content = viewer.canvas.text_content
74
- jsonPage[str(page)]["content"] = [content]
75
- # inline_image:'''
76
- '''Inline images are aligned with the text,
77
- and are usually content images like photos, charts, or graphs.'''
78
- inline_image = viewer.canvas.inline_images
79
- jsonPage[str(page)]["inline_image"] = [inline_image]
80
- pdf.close()
81
- '''Output Format:
82
- jsonPage[str(page)]["normalText"]
83
- jsonPage[str(page)]["cleanText"]
84
- jsonPage[str(page)]["image"]
85
- jsonPage[str(page)]["form"]
86
- jsonPage[str(page)]["content"]'''
87
- return jsonPage
88
- def extractTable(self,pages,saveFile=None,outputFormat=None):
89
  '''pages (str, int, iterable of int, optional) –
90
  An optional values specifying pages to extract from. It allows str,`int`, iterable of :int. Default: 1
91
  Examples: '1-2,3', 'all', [1,2]'''
@@ -106,37 +148,40 @@ class PDF(): # using PyPDF2
106
  df = []
107
  print("No tables found in PDF file")
108
  return df
109
- def mergeTextinJson(self,jsonPDF):
110
- # pdf
111
- #cl = cleanGenText()
112
  cl = cleanText.cleanGenText()
113
  pdfText = ""
114
- for page in jsonPDF:
115
- # page is "\n\n"
116
- if len(jsonPDF[page]["normalText"]) > 0:
117
- for i in range(len(jsonPDF[page]["normalText"])):
118
- text = jsonPDF[page]["normalText"][i]
119
- if len(text)>0:
120
- text = cl.removeTabWhiteSpaceNewLine(text)
121
- text = cl.removeExtraSpaceBetweenWords(text)
122
- jsonPDF[page]["normalText"][i] = text
123
- # same page is just a dot.
124
- if i-1 > 0:
125
- if jsonPDF[page]["normalText"][i-1][-1] != ".":
126
- pdfText += ". "
127
- pdfText += jsonPDF[page]["normalText"][i]
128
- if len(jsonPDF[page]["normalText"][i])>0:
129
- if jsonPDF[page]["normalText"][i][-1]!=".":
130
- pdfText += "."
131
- pdfText += "\n\n"
132
  return pdfText
 
133
  def getReference(self):
134
  pass
 
135
  def getSupMaterial(self):
136
  pass
 
137
  def removeHeaders(self):
138
  pass
 
139
  def removeFooters(self):
140
  pass
 
141
  def removeReference(self):
142
  pass
 
6
  import requests
7
  from NER import cleanText
8
  #!pip install tabula-py
 
9
  import tabula
10
+ import fitz # PyMuPDF
11
+ import os
12
+
13
+ class PDF():
14
  def __init__(self, pdf, saveFolder, doi=None):
15
  self.pdf = pdf
16
  self.doi = doi
17
  self.saveFolder = saveFolder
18
+
19
  def openPDFFile(self):
20
  if "https" in self.pdf:
21
  name = self.pdf.split("/")[-1]
22
  name = self.downloadPDF(self.saveFolder)
23
  if name != "no pdfLink to download":
24
+ fileToOpen = os.path.join(self.saveFolder, name)
25
+ else:
26
+ fileToOpen = self.pdf
27
+ else:
28
+ fileToOpen = self.pdf
29
  return open(fileToOpen, "rb")
30
+
31
  def downloadPDF(self, saveFolder):
32
  pdfLink = ''
33
+ if ".pdf" not in self.pdf and "https" not in self.pdf:
34
  r = requests.get(self.pdf)
35
  soup = BeautifulSoup(r.content, 'html.parser')
36
  links = soup.find_all("a")
37
  for link in links:
38
+ if ".pdf" in link.get("href", ""):
39
  if self.doi in link.get("href"):
40
  pdfLink = link.get("href")
41
  break
42
  else:
43
  pdfLink = self.pdf
44
+
45
  if pdfLink != '':
46
  response = requests.get(pdfLink)
47
  name = pdfLink.split("/")[-1]
48
+ print("inside download PDF and name and link are: ", pdfLink, name)
49
+ print("saveFolder is: ", saveFolder)
50
+ with open(os.path.join(saveFolder, name), 'wb') as pdf:
51
+ print("len of response content: ", len(response.content))
52
+ pdf.write(response.content)
53
  print("pdf downloaded")
54
  return name
55
  else:
56
  return "no pdfLink to download"
57
+
58
  def extractText(self):
59
+ fileToOpen = self.openPDFFile().name
60
+ try:
61
+ doc = fitz.open(fileToOpen)
62
+ text = ""
63
+ for page in doc:
64
+ text += page.get_text("text") + "\n\n"
65
+ doc.close()
66
+
67
+ if len(text.strip()) < 100:
68
+ print("Fallback to PDFReader due to weak text extraction.")
69
+ text = self.extractTextWithPDFReader()
70
+ return text
71
+ except Exception as e:
72
+ print("Failed with PyMuPDF, fallback to PDFReader:", e)
73
+ return self.extractTextWithPDFReader()
74
+ def extract_text_excluding_tables(self):
75
+ fileToOpen = self.openPDFFile().name
76
+ text = ""
77
+ try:
78
+ doc = fitz.open(fileToOpen)
79
+ for page in doc:
80
+ blocks = page.get_text("dict")["blocks"]
81
+
82
+ for block in blocks:
83
+ if block["type"] == 0: # text block
84
+ lines = block.get("lines", [])
85
+
86
+ if not lines:
87
+ continue
88
+ avg_words_per_line = sum(len(l["spans"]) for l in lines) / len(lines)
89
+ if avg_words_per_line > 1: # Heuristic: paragraph-like blocks
90
+ for line in lines:
91
+ text += " ".join(span["text"] for span in line["spans"]) + "\n"
92
+ doc.close()
93
+ if len(text.strip()) < 100:
94
+ print("Fallback to PDFReader due to weak text extraction.")
95
+ text = self.extractTextWithPDFReader()
96
+ return text
97
+ except Exception as e:
98
+ print("Failed with PyMuPDF, fallback to PDFReader:", e)
99
+ return self.extractTextWithPDFReader()
100
+
101
+ def extractTextWithPDFReader(self):
102
  jsonPage = {}
103
+ try:
104
+ pdf = self.openPDFFile()
105
+ print("open pdf file")
106
+ print(pdf)
107
+ doc = PDFDocument(pdf)
108
+ viewer = SimplePDFViewer(pdf)
109
+ all_pages = [p for p in doc.pages()]
110
+ cl = cleanText.cleanGenText()
111
+ pdfText = ""
112
+ for page in range(1, len(all_pages)):
113
+ viewer.navigate(page)
114
+ viewer.render()
115
+ if str(page) not in jsonPage:
116
+ jsonPage[str(page)] = {}
117
+ text = "".join(viewer.canvas.strings)
118
+ clean, filteredWord = cl.textPreprocessing(text)
119
+ jsonPage[str(page)]["normalText"] = [text]
120
+ jsonPage[str(page)]["cleanText"] = [' '.join(filteredWord)]
121
+ jsonPage[str(page)]["image"] = [viewer.canvas.images]
122
+ jsonPage[str(page)]["form"] = [viewer.canvas.forms]
123
+ jsonPage[str(page)]["content"] = [viewer.canvas.text_content]
124
+ jsonPage[str(page)]["inline_image"] = [viewer.canvas.inline_images]
125
+ pdf.close()
126
+ except Exception: # avoid a bare except, which would also swallow KeyboardInterrupt/SystemExit
127
+ jsonPage = {}
128
+ return self.mergeTextinJson(jsonPage)
129
+
130
+ def extractTable(self,pages="all",saveFile=None,outputFormat=None):
131
  '''pages (str, int, iterable of int, optional) –
132
  An optional values specifying pages to extract from. It allows str,`int`, iterable of :int. Default: 1
133
  Examples: '1-2,3', 'all', [1,2]'''
 
148
  df = []
149
  print("No tables found in PDF file")
150
  return df
151
+
152
+ def mergeTextinJson(self, jsonPDF):
 
153
  cl = cleanText.cleanGenText()
154
  pdfText = ""
155
+ if jsonPDF:
156
+ for page in jsonPDF:
157
+ if len(jsonPDF[page]["normalText"]) > 0:
158
+ for i in range(len(jsonPDF[page]["normalText"])):
159
+ text = jsonPDF[page]["normalText"][i]
160
+ if len(text) > 0:
161
+ text = cl.removeTabWhiteSpaceNewLine(text)
162
+ text = cl.removeExtraSpaceBetweenWords(text)
163
+ jsonPDF[page]["normalText"][i] = text
164
+ if i - 1 > 0:
165
+ if jsonPDF[page]["normalText"][i - 1][-1] != ".":
166
+ pdfText += ". "
167
+ pdfText += jsonPDF[page]["normalText"][i]
168
+ if len(jsonPDF[page]["normalText"][i]) > 0:
169
+ if jsonPDF[page]["normalText"][i][-1] != ".":
170
+ pdfText += "."
171
+ pdfText += "\n\n"
 
172
  return pdfText
173
+
174
  def getReference(self):
175
  pass
176
+
177
  def getSupMaterial(self):
178
  pass
179
+
180
  def removeHeaders(self):
181
  pass
182
+
183
  def removeFooters(self):
184
  pass
185
+
186
  def removeReference(self):
187
  pass
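
For context, a minimal usage sketch of the updated PDF class (assumptions: the NER package is importable, a local ./save folder exists, and the URL and DOI below are placeholders). extractText() now tries PyMuPDF (fitz) first and falls back to the pdfreader-based extractTextWithPDFReader() when the extracted text looks too short.

from NER.PDF.pdf import PDF

pdf = PDF("https://example.org/article.pdf", saveFolder="./save", doi="10.1000/example")  # hypothetical inputs
text = pdf.extractText()                         # PyMuPDF first, pdfreader fallback on weak output
body_only = pdf.extract_text_excluding_tables()  # keeps paragraph-like blocks, skips table-like ones
tables = pdf.extractTable(pages="all")           # tabula-based extraction; pages now defaults to "all"
print(len(text), len(body_only), len(tables))
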
NER/WordDoc/__pycache__/wordDoc.cpython-311.pyc CHANGED
Binary files a/NER/WordDoc/__pycache__/wordDoc.cpython-311.pyc and b/NER/WordDoc/__pycache__/wordDoc.cpython-311.pyc differ
 
NER/WordDoc/wordDoc.py CHANGED
@@ -81,6 +81,35 @@ class wordDoc(): # using python-docx
81
  tableData += '\n'
82
  json["Section" + str(s)]["Table"+str(i)] = tableData
83
  return json
84
  def extractTableAsExcel(self):
85
  getDoc = ''
86
  try:
 
81
  tableData += '\n'
82
  json["Section" + str(s)]["Table"+str(i)] = tableData
83
  return json
84
+ def extractTableAsList(self):
85
+ tables = []
86
+ try:
87
+ doc = Document()
88
+ doc.LoadFromFile(self.wordDoc)
89
+ except:
90
+ response = requests.get(self.wordDoc)
91
+ name = self.wordDoc.split("/")[-1]
92
+ with open(os.path.join(self.saveFolder, name), "wb") as f:
93
+ f.write(response.content)
94
+ doc = Document()
95
+ doc.LoadFromFile(os.path.join(self.saveFolder, name))
96
+
97
+ for s in range(doc.Sections.Count):
98
+ section = doc.Sections.get_Item(s)
99
+ for i in range(section.Tables.Count):
100
+ table = section.Tables.get_Item(i)
101
+ table_data = []
102
+ for row in range(table.Rows.Count):
103
+ row_data = []
104
+ for cell in range(table.Rows.get_Item(row).Cells.Count):
105
+ cell_obj = table.Rows.get_Item(row).Cells.get_Item(cell)
106
+ cell_text = ""
107
+ for p in range(cell_obj.Paragraphs.Count):
108
+ cell_text += cell_obj.Paragraphs.get_Item(p).Text.strip() + " "
109
+ row_data.append(cell_text.strip())
110
+ table_data.append(row_data)
111
+ tables.append(table_data)
112
+ return tables
113
  def extractTableAsExcel(self):
114
  getDoc = ''
115
  try:
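
A minimal sketch of the new extractTableAsList() (assumptions: the wordDoc constructor takes the document path and a save folder, Spire.Doc and pandas are installed, and the .docx path is a placeholder). Each returned table is a list of rows, each row a list of cell strings, which maps directly onto a DataFrame.

import pandas as pd
from NER.WordDoc.wordDoc import wordDoc

doc = wordDoc("supplement.docx", "./save")  # hypothetical constructor arguments
tables = doc.extractTableAsList()           # one list of rows per table
for t in tables:
    # treat the first row as a header when there is more than one row
    df = pd.DataFrame(t[1:], columns=t[0]) if len(t) > 1 else pd.DataFrame(t)
    print(df.shape)
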
NER/html/__pycache__/extractHTML.cpython-311.pyc CHANGED
Binary files a/NER/html/__pycache__/extractHTML.cpython-311.pyc and b/NER/html/__pycache__/extractHTML.cpython-311.pyc differ
 
NER/html/extractHTML.py CHANGED
@@ -1,166 +1,222 @@
1
- #!pip install bs4
2
- # reference: https://www.crummy.com/software/BeautifulSoup/bs4/doc/#for-html-documents
3
- from bs4 import BeautifulSoup
4
- import requests
5
- from DefaultPackages import openFile, saveFile
6
- from NER import cleanText
7
- import pandas as pd
8
- import os
9
- class HTML():
10
- def __init__(self, htmlFile, htmlLink):
11
- self.htmlLink = htmlLink
12
- self.htmlFile = htmlFile
13
- def openHTMLFile(self):
14
- try:
15
- if self.htmlLink != "None":
16
- r = requests.get(self.htmlLink)
17
- soup = BeautifulSoup(r.content, 'html.parser')
18
- else:
19
- with open(self.htmlFile) as fp:
20
- soup = BeautifulSoup(fp, 'html.parser')
21
- except:
22
- print("cannot open html file")
23
- soup = None
24
- return soup
25
- def getText(self):
26
- soup = self.openHTMLFile()
27
- text = ""
28
- if soup:
29
- s = soup.find_all("html")
30
- for t in range(len(s)):
31
- text = s[t].get_text()
32
- cl = cleanText.cleanGenText()
33
- text = cl.removeExtraSpaceBetweenWords(text)
34
- return text
35
- def getListSection(self, scienceDirect=None):
36
- json = {}
37
- text = ""
38
- textJson, textHTML = "",""
39
- if scienceDirect == None:
40
- soup = self.openHTMLFile()
41
- # get list of section
42
- json = {}
43
- if soup:
44
- for h2Pos in range(len(soup.find_all('h2'))):
45
- if soup.find_all('h2')[h2Pos].text not in json:
46
- json[soup.find_all('h2')[h2Pos].text] = []
47
- if h2Pos + 1 < len(soup.find_all('h2')):
48
- content = soup.find_all('h2')[h2Pos].find_next("p")
49
- nexth2Content = soup.find_all('h2')[h2Pos+1].find_next("p")
50
- while content.text != nexth2Content.text:
51
- json[soup.find_all('h2')[h2Pos].text].append(content.text)
52
- content = content.find_next("p")
53
- else:
54
- content = soup.find_all('h2')[h2Pos].find_all_next("p",string=True)
55
- json[soup.find_all('h2')[h2Pos].text] = list(i.text for i in content)
56
- # format
57
- '''json = {'Abstract':[], 'Introduction':[], 'Methods'[],
58
- 'Results':[], 'Discussion':[], 'References':[],
59
- 'Acknowledgements':[], 'Author information':[], 'Ethics declarations':[],
60
- 'Additional information':[], 'Electronic supplementary material':[],
61
- 'Rights and permissions':[], 'About this article':[], 'Search':[], 'Navigation':[]}'''
62
- if scienceDirect!= None or len(json)==0:
63
- # Replace with your actual Elsevier API key
64
- api_key = os.environ["SCIENCE_DIRECT_API"]
65
- # ScienceDirect article DOI or PI (Example DOI)
66
- doi = self.htmlLink.split("https://doi.org/")[-1] #"10.1016/j.ajhg.2011.01.009"
67
- # Base URL for the Elsevier API
68
- base_url = "https://api.elsevier.com/content/article/doi/"
69
- # Set headers with API key
70
- headers = {
71
- "Accept": "application/json",
72
- "X-ELS-APIKey": api_key
73
- }
74
- # Make the API request
75
- response = requests.get(base_url + doi, headers=headers)
76
- # Check if the request was successful
77
- if response.status_code == 200:
78
- data = response.json()
79
- supp_data = data["full-text-retrieval-response"]#["coredata"]["link"]
80
- if "originalText" in list(supp_data.keys()):
81
- if type(supp_data["originalText"])==str:
82
- json["originalText"] = [supp_data["originalText"]]
83
- if type(supp_data["originalText"])==dict:
84
- json["originalText"] = [supp_data["originalText"][key] for key in supp_data["originalText"]]
85
- else:
86
- if type(supp_data)==dict:
87
- for key in supp_data:
88
- json[key] = [supp_data[key]]
89
- if json:
90
- textJson = self.mergeTextInJson(json)
91
- textHTML = self.getText()
92
- if len(textHTML) > len(textJson):
93
- text = textHTML
94
- else: text = textJson
95
- return text #json
96
- def getReference(self):
97
- # get reference to collect more next data
98
- ref = []
99
- json = self.getListSection()
100
- for key in json["References"]:
101
- ct = cleanText.cleanGenText(key)
102
- cleanText, filteredWord = ct.cleanText()
103
- if cleanText not in ref:
104
- ref.append(cleanText)
105
- return ref
106
- def getSupMaterial(self):
107
- # check if there is material or not
108
- json = {}
109
- soup = self.openHTMLFile()
110
- for h2Pos in range(len(soup.find_all('h2'))):
111
- if "supplementary" in soup.find_all('h2')[h2Pos].text.lower() or "material" in soup.find_all('h2')[h2Pos].text.lower() or "additional" in soup.find_all('h2')[h2Pos].text.lower() or "support" in soup.find_all('h2')[h2Pos].text.lower():
112
- #print(soup.find_all('h2')[h2Pos].find_next("a").get("href"))
113
- link, output = [],[]
114
- if soup.find_all('h2')[h2Pos].text not in json:
115
- json[soup.find_all('h2')[h2Pos].text] = []
116
- for l in soup.find_all('h2')[h2Pos].find_all_next("a",href=True):
117
- link.append(l["href"])
118
- if h2Pos + 1 < len(soup.find_all('h2')):
119
- nexth2Link = soup.find_all('h2')[h2Pos+1].find_next("a",href=True)["href"]
120
- if nexth2Link in link:
121
- link = link[:link.index(nexth2Link)]
122
- # only take links having "https" in that
123
- for i in link:
124
- if "https" in i: output.append(i)
125
- json[soup.find_all('h2')[h2Pos].text].extend(output)
126
- return json
127
- def extractTable(self):
128
- soup = self.openHTMLFile()
129
- df = []
130
- try:
131
- df = pd.read_html(str(soup))
132
- except ValueError:
133
- df = []
134
- print("No tables found in HTML file")
135
- return df
136
- def mergeTextInJson(self,jsonHTML):
137
- cl = cleanText.cleanGenText()
138
- #cl = cleanGenText()
139
- htmlText = ""
140
- for sec in jsonHTML:
141
- # section is "\n\n"
142
- if len(jsonHTML[sec]) > 0:
143
- for i in range(len(jsonHTML[sec])):
144
- # same section is just a dot.
145
- text = jsonHTML[sec][i]
146
- if len(text)>0:
147
- #text = cl.removeTabWhiteSpaceNewLine(text)
148
- #text = cl.removeExtraSpaceBetweenWords(text)
149
- text, filteredWord = cl.textPreprocessing(text, keepPeriod=True)
150
- jsonHTML[sec][i] = text
151
- if i-1 >= 0:
152
- if len(jsonHTML[sec][i-1])>0:
153
- if jsonHTML[sec][i-1][-1] != ".":
154
- htmlText += ". "
155
- htmlText += jsonHTML[sec][i]
156
- if len(jsonHTML[sec][i]) > 0:
157
- if jsonHTML[sec][i][-1]!=".":
158
- htmlText += "."
159
- htmlText += "\n\n"
160
- return htmlText
161
- def removeHeaders(self):
162
- pass
163
- def removeFooters(self):
164
- pass
165
- def removeReferences(self):
166
  pass
 
1
+ # reference: https://www.crummy.com/software/BeautifulSoup/bs4/doc/#for-html-documents
2
+ from bs4 import BeautifulSoup
3
+ import requests
4
+ from DefaultPackages import openFile, saveFile
5
+ from NER import cleanText
6
+ import pandas as pd
7
+ class HTML():
8
+ def __init__(self, htmlFile, htmlLink):
9
+ self.htmlLink = htmlLink
10
+ self.htmlFile = htmlFile
11
+ # def openHTMLFile(self):
12
+ # headers = {
13
+ # "User-Agent": (
14
+ # "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
15
+ # "AppleWebKit/537.36 (KHTML, like Gecko) "
16
+ # "Chrome/114.0.0.0 Safari/537.36"
17
+ # ),
18
+ # "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
19
+ # "Referer": self.htmlLink,
20
+ # "Connection": "keep-alive"
21
+ # }
22
+
23
+ # session = requests.Session()
24
+ # session.headers.update(headers)
25
+
26
+ # if self.htmlLink != "None":
27
+ # try:
28
+ # r = session.get(self.htmlLink, allow_redirects=True, timeout=15)
29
+ # if r.status_code != 200:
30
+ # print(f"❌ HTML GET failed: {r.status_code} — {self.htmlLink}")
31
+ # return BeautifulSoup("", 'html.parser')
32
+ # soup = BeautifulSoup(r.content, 'html.parser')
33
+ # except Exception as e:
34
+ # print(f"❌ Exception fetching HTML: {e}")
35
+ # return BeautifulSoup("", 'html.parser')
36
+ # else:
37
+ # with open(self.htmlFile) as fp:
38
+ # soup = BeautifulSoup(fp, 'html.parser')
39
+ # return soup
40
+ from lxml.etree import ParserError, XMLSyntaxError
41
+
42
+ def openHTMLFile(self):
43
+ not_need_domain = ['https://broadinstitute.github.io/picard/',
44
+ 'https://software.broadinstitute.org/gatk/best-practices/',
45
+ 'https://www.ncbi.nlm.nih.gov/genbank/',
46
+ 'https://www.mitomap.org/']
47
+ headers = {
48
+ "User-Agent": (
49
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
50
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
51
+ "Chrome/114.0.0.0 Safari/537.36"
52
+ ),
53
+ "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
54
+ "Referer": self.htmlLink,
55
+ "Connection": "keep-alive"
56
+ }
57
+
58
+ session = requests.Session()
59
+ session.headers.update(headers)
60
+ if self.htmlLink in not_need_domain:
61
+ return BeautifulSoup("", 'html.parser')
62
+ try:
63
+ if self.htmlLink and self.htmlLink != "None":
64
+ r = session.get(self.htmlLink, allow_redirects=True, timeout=15)
65
+ if r.status_code != 200 or not r.text.strip():
66
+ print(f"❌ HTML GET failed ({r.status_code}) or empty page: {self.htmlLink}")
67
+ return BeautifulSoup("", 'html.parser')
68
+ soup = BeautifulSoup(r.content, 'html.parser')
69
+ else:
70
+ with open(self.htmlFile, encoding='utf-8') as fp:
71
+ soup = BeautifulSoup(fp, 'html.parser')
72
+ except (ParserError, XMLSyntaxError, OSError) as e:
73
+ print(f"🚫 HTML parse error for {self.htmlLink}: {type(e).__name__}")
74
+ return BeautifulSoup("", 'html.parser')
75
+ except Exception as e:
76
+ print(f"❌ General exception for {self.htmlLink}: {e}")
77
+ return BeautifulSoup("", 'html.parser')
78
+
79
+ return soup
80
+
81
+ def getText(self):
82
+ soup = self.openHTMLFile()
83
+ s = soup.find_all("html")
84
+ text = ""
85
+ if s:
86
+ for t in range(len(s)):
87
+ text = s[t].get_text()
88
+ cl = cleanText.cleanGenText()
89
+ text = cl.removeExtraSpaceBetweenWords(text)
90
+ return text
91
+ def getListSection(self, scienceDirect=None):
92
+ json = {}
93
+ text = ""
94
+ textJson, textHTML = "",""
95
+ if scienceDirect == None:
96
+ soup = self.openHTMLFile()
97
+ # get list of section
98
+ json = {}
99
+ for h2Pos in range(len(soup.find_all('h2'))):
100
+ if soup.find_all('h2')[h2Pos].text not in json:
101
+ json[soup.find_all('h2')[h2Pos].text] = []
102
+ if h2Pos + 1 < len(soup.find_all('h2')):
103
+ content = soup.find_all('h2')[h2Pos].find_next("p")
104
+ nexth2Content = soup.find_all('h2')[h2Pos+1].find_next("p")
105
+ while content.text != nexth2Content.text:
106
+ json[soup.find_all('h2')[h2Pos].text].append(content.text)
107
+ content = content.find_next("p")
108
+ else:
109
+ content = soup.find_all('h2')[h2Pos].find_all_next("p",string=True)
110
+ json[soup.find_all('h2')[h2Pos].text] = list(i.text for i in content)
111
+ # format
112
+ '''json = {'Abstract':[], 'Introduction':[], 'Methods'[],
113
+ 'Results':[], 'Discussion':[], 'References':[],
114
+ 'Acknowledgements':[], 'Author information':[], 'Ethics declarations':[],
115
+ 'Additional information':[], 'Electronic supplementary material':[],
116
+ 'Rights and permissions':[], 'About this article':[], 'Search':[], 'Navigation':[]}'''
117
+ if scienceDirect!= None or len(json)==0:
118
+ # Replace with your actual Elsevier API key
119
+ api_key = os.environ.get("SCIENCE_DIRECT_API") # read the Elsevier API key from the environment rather than hard-coding it
120
+ # ScienceDirect article DOI or PI (Example DOI)
121
+ doi = self.htmlLink.split("https://doi.org/")[-1] #"10.1016/j.ajhg.2011.01.009"
122
+ # Base URL for the Elsevier API
123
+ base_url = "https://api.elsevier.com/content/article/doi/"
124
+ # Set headers with API key
125
+ headers = {
126
+ "Accept": "application/json",
127
+ "X-ELS-APIKey": api_key
128
+ }
129
+ # Make the API request
130
+ response = requests.get(base_url + doi, headers=headers)
131
+ # Check if the request was successful
132
+ if response.status_code == 200:
133
+ data = response.json()
134
+ supp_data = data["full-text-retrieval-response"]#["coredata"]["link"]
135
+ if "originalText" in list(supp_data.keys()):
136
+ if type(supp_data["originalText"])==str:
137
+ json["originalText"] = [supp_data["originalText"]]
138
+ if type(supp_data["originalText"])==dict:
139
+ json["originalText"] = [supp_data["originalText"][key] for key in supp_data["originalText"]]
140
+ else:
141
+ if type(supp_data)==dict:
142
+ for key in supp_data:
143
+ json[key] = [supp_data[key]]
144
+
145
+ textJson = self.mergeTextInJson(json)
146
+ textHTML = self.getText()
147
+ if len(textHTML) > len(textJson):
148
+ text = textHTML
149
+ else: text = textJson
150
+ return text #json
151
+ def getReference(self):
152
+ # get reference to collect more next data
153
+ ref = []
154
+ json = self.getListSection()
155
+ for key in json["References"]:
156
+ ct = cleanText.cleanGenText(key)
157
+ cleanText, filteredWord = ct.cleanText()
158
+ if cleanText not in ref:
159
+ ref.append(cleanText)
160
+ return ref
161
+ def getSupMaterial(self):
162
+ # check if there is material or not
163
+ json = {}
164
+ soup = self.openHTMLFile()
165
+ for h2Pos in range(len(soup.find_all('h2'))):
166
+ if "supplementary" in soup.find_all('h2')[h2Pos].text.lower() or "material" in soup.find_all('h2')[h2Pos].text.lower() or "additional" in soup.find_all('h2')[h2Pos].text.lower() or "support" in soup.find_all('h2')[h2Pos].text.lower():
167
+ #print(soup.find_all('h2')[h2Pos].find_next("a").get("href"))
168
+ link, output = [],[]
169
+ if soup.find_all('h2')[h2Pos].text not in json:
170
+ json[soup.find_all('h2')[h2Pos].text] = []
171
+ for l in soup.find_all('h2')[h2Pos].find_all_next("a",href=True):
172
+ link.append(l["href"])
173
+ if h2Pos + 1 < len(soup.find_all('h2')):
174
+ nexth2Link = soup.find_all('h2')[h2Pos+1].find_next("a",href=True)["href"]
175
+ if nexth2Link in link:
176
+ link = link[:link.index(nexth2Link)]
177
+ # only take links having "https" in that
178
+ for i in link:
179
+ if "https" in i: output.append(i)
180
+ json[soup.find_all('h2')[h2Pos].text].extend(output)
181
+ return json
182
+ def extractTable(self):
183
+ soup = self.openHTMLFile()
184
+ df = []
185
+ if len(soup)>0:
186
+ try:
187
+ df = pd.read_html(str(soup))
188
+ except ValueError:
189
+ df = []
190
+ print("No tables found in HTML file")
191
+ return df
192
+ def mergeTextInJson(self,jsonHTML):
193
+ cl = cleanText.cleanGenText()
194
+ #cl = cleanGenText()
195
+ htmlText = ""
196
+ for sec in jsonHTML:
197
+ # section is "\n\n"
198
+ if len(jsonHTML[sec]) > 0:
199
+ for i in range(len(jsonHTML[sec])):
200
+ # same section is just a dot.
201
+ text = jsonHTML[sec][i]
202
+ if len(text)>0:
203
+ #text = cl.removeTabWhiteSpaceNewLine(text)
204
+ #text = cl.removeExtraSpaceBetweenWords(text)
205
+ text, filteredWord = cl.textPreprocessing(text, keepPeriod=True)
206
+ jsonHTML[sec][i] = text
207
+ if i-1 >= 0:
208
+ if len(jsonHTML[sec][i-1])>0:
209
+ if jsonHTML[sec][i-1][-1] != ".":
210
+ htmlText += ". "
211
+ htmlText += jsonHTML[sec][i]
212
+ if len(jsonHTML[sec][i]) > 0:
213
+ if jsonHTML[sec][i][-1]!=".":
214
+ htmlText += "."
215
+ htmlText += "\n\n"
216
+ return htmlText
217
+ def removeHeaders(self):
218
+ pass
219
+ def removeFooters(self):
220
+ pass
221
+ def removeReferences(self):
222
  pass
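
A minimal usage sketch of the reworked HTML class (assumptions: the NER package is importable and the DOI link is a placeholder). openHTMLFile() now sends browser-like headers and returns an empty soup instead of None on failure, and getListSection() falls back to the Elsevier full-text API (SCIENCE_DIRECT_API key) when no h2 sections are found.

from NER.html.extractHTML import HTML

page = HTML(htmlFile="", htmlLink="https://doi.org/10.1016/j.example.2024.001")  # hypothetical link
plain = page.getText()          # whole-page text, whitespace-normalized
merged = page.getListSection()  # section-merged text, or Elsevier full text as a fallback
tables = page.extractTable()    # pandas.read_html on the fetched page
print(len(plain), len(merged), len(tables))
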
NER/word2Vec/__pycache__/word2vec.cpython-311.pyc CHANGED
Binary files a/NER/word2Vec/__pycache__/word2vec.cpython-311.pyc and b/NER/word2Vec/__pycache__/word2vec.cpython-311.pyc differ
 
NER/word2Vec/testModel/test_model.model CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:734185116a1d2099dba0d04efc0eb1b7e0e8213fe1259b57bbcb7aaac3cd46ea
3
- size 133
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:193f58915e5f895c3d00d1012a691b15ad051d2b9eaf83662a2a7e3af326e923
3
+ size 25214
NER/word2Vec/testModel/test_model_updated.model CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1b1b785c79991b857b364ee9863985eaf845087efb1aa40a6b9cfae3b2a50012
3
- size 133
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78fc82c69afe5c74ae631f389025e5195ad613b62a7c5e42fb7f3f0b7cda99c3
3
+ size 30688
NER/word2Vec/word2vec.py CHANGED
@@ -1,369 +1,436 @@
1
- '''WORD TO VECTOR'''
2
- import pandas as pd
3
- import json
4
- import gensim
5
- import spacy
6
- from DefaultPackages import openFile, saveFile
7
- from NER import cleanText
8
- from gensim.models.keyedvectors import KeyedVectors
9
- from gensim.test.utils import common_texts
10
- from gensim.models.word2vec import Word2Vec
11
- from gensim.scripts.glove2word2vec import glove2word2vec
12
- from gensim.test.utils import datapath, get_tmpfile
13
- import sys
14
- import subprocess
15
- # can try multiprocessing to run quicker
16
- import multiprocessing
17
- import copy
18
- sys.setrecursionlimit(1000)
19
- # creat folder word2Vec
20
- #! mkdir /content/drive/MyDrive/CollectData/NER/word2Vec
21
- # create word2vec model
22
- #model = KeyedVectors.load_word2vec_format('/content/drive/MyDrive/CollectData/NER/word2Vec', binary=True)
23
- '''Some notes for this model
24
- sometimes when we do the corpus, there are some adverbs which are unnecessary but might be seen as
25
- a similar word to the word we are finding, so can we try to preprocess text so that
26
- we make the corpus more effective and only contains the important words. Then when we
27
- train the model, the important words will be seen as important. Or
28
- when we already have the similar list of words, we can remove the words in there
29
- that are stopwords/unnecessary words.'''
30
- ### For more complex analysis, consider using sentence embedding models like "Doc2Vec" to represent the meaning of entire sentences instead of just individual words
31
- class word2Vec():
32
- def __init__(self, nameFile=None, modelName=None):
33
- self.nameFile = nameFile
34
- self.modelName = modelName
35
- def spacy_similarity(self, word):
36
- # when use word2vec, try medium or large is better
37
- # maybe try odc similarity?
38
- nlp = spacy.load("en_core_web_lg")
39
- doc = nlp(word)
40
- for token1 in doc:
41
- for token2 in doc:
42
- print(token1.text, token2.text, token1.similarity(token2))
43
- pass
44
- # clean text before transform to corpus
45
- def cleanTextBeforeCorpus(self,oriText, doi=None):
46
- cl = cleanText.cleanGenText()
47
- #cl = cleanGenText()
48
- output = ""
49
- alreadyRemoveDoi = False
50
- for word in oriText.split(" "):
51
- # remove DOI
52
- if doi != None and doi in oriText:
53
- if alreadyRemoveDoi == False:
54
- newWord = cl.removeDOI(word,doi)
55
- if len(newWord) > 0 and newWord != word:
56
- alreadyRemoveDoi = True
57
- word = newWord
58
- # remove punctuation
59
- # split the sticked words
60
- #word = cl.splitStickWords(word)
61
- # remove punctuation
62
- word = cl.removePunct(word,True)
63
- # remove URL
64
- word = cl.removeURL(word)
65
- # remove HTMLTag
66
- word = cl.removeHTMLTag(word)
67
- # remove tab, white space, newline
68
- word = cl.removeTabWhiteSpaceNewLine(word)
69
- # optional: remove stopwords
70
- #word = cl.removeStopWords(word)
71
- if len(word)>0:
72
- output += word + " "
73
- return output
74
- def cleanAllTextBeforeCorpus(self, allText, doi=None):
75
- cleanOutput = ""
76
- remove = "Evaluation Warning: The document was created with Spire.Doc for Python."
77
- if len(allText) > 0:
78
- corpusText = allText
79
- for pos in range(len(corpusText.split("\n\n"))):
80
- if len(corpusText.split("\n\n")[pos]) > 0:
81
- lines = corpusText.split("\n\n")[pos]
82
- for line in lines.split("\n"):
83
- if remove in line: line = line.replace(remove, "")
84
- clean_text = self.cleanTextBeforeCorpus(line, doi)
85
- cleanOutput += clean_text + "\n"
86
- cleanOutput += "\n\n"
87
- return cleanOutput
88
- def tableTransformToCorpusText(self, df, excelFile=None):
89
- # PDF, Excel, WordDoc
90
- #cl = cleanText.cleanGenText()
91
- corpus = {}
92
- # PDF or df
93
- if excelFile == None:
94
- if len(df) > 0:
95
- try:
96
- for i in range(len(df)):
97
- # each new dimension/page is considered to be a sentence which ends with the period.
98
- # each new line is a new list, and each new df is a new corpus
99
- outputDF = []
100
- text = df[i].values.tolist()
101
- if len(text) > 0:
102
- outputRowDF = self.helperRowTableToCorpus(text)
103
- #outputColDF = self.helperColTableToCorpus(text)
104
- outputDF.extend(outputRowDF)
105
- #outputDF.extend(outputColDF)
106
- if len(outputDF) > 0:
107
- corpus["corpus" + str(i)] = outputDF
108
- except:
109
- outputDF = []
110
- text = df.values.tolist()
111
- if len(text) > 0:
112
- outputRowDF = self.helperRowTableToCorpus(text)
113
- #outputColDF = self.helperColTableToCorpus(text)
114
- outputDF.extend(outputRowDF)
115
- #outputDF.extend(outputColDF)
116
- if len(outputDF) > 0:
117
- corpus["corpus0"] = outputDF
118
- else:
119
- try:
120
- df = pd.ExcelFile(excelFile)
121
- except:
122
- if filepath.endswith('.xls'):
123
- df = pd.read_excel(filepath, engine='xlrd')
124
- else:
125
- df = pd.read_excel(filepath, engine='openpyxl')
126
- sheetNames = df.sheet_names
127
- output = []
128
- if len(sheetNames) > 0:
129
- for s in range(len(sheetNames)):
130
- outputDF = []
131
- with pd.ExcelFile(excelFile) as xls:
132
- data = pd.read_excel(xls, sheetNames[s])
133
- if sheetNames[s] != 'Evaluation Warning':
134
- text = data.values.tolist()
135
- if len(text) > 0:
136
- outputRowDF = self.helperRowTableToCorpus(text)
137
- #outputColDF = self.helperColTableToCorpus(text)
138
- outputDF.extend(outputRowDF)
139
- #outputDF.extend(outputColDF)
140
- if len(outputDF) > 0:
141
- corpus["corpus" + str(s)] = outputDF
142
- return corpus
143
- def helperRowTableToCorpus(self, textList):
144
- #cl = cleanGenText()
145
- cl = cleanText.cleanGenText()
146
- stopWords = ["NaN","Unnamed:","nan"]
147
- outputDF = []
148
- for line in textList:
149
- outputLine = []
150
- for words in line:
151
- words = str(words)
152
- if len(words) > 0:
153
- for word in words.split(" "):
154
- # remove specific stopwords for table: "NaN", "Unnamed: 0", row index: if the number appears first, it's just a row index; keep "KM1"
155
- if str(word) not in stopWords: # remove "NaN", "Unnamed:","nan"
156
- #word = cl.splitStickWords(word)
157
- word = cl.removePunct(word)
158
- word = " ".join(cl.removeStopWords(word))
159
- word = cl.removeTabWhiteSpaceNewLine(word)
160
- if len(word) > 1:
161
- if len(word.split(" ")) > 1:
162
- for x in word.split(" "):
163
- if len(x) > 1 and x.isnumeric()==False:
164
- outputLine.append(x.lower())
165
- else:
166
- if word.isnumeric() == False:
167
- outputLine.append(word.lower())
168
- if len(outputLine) > 0:
169
- outputDF.append(outputLine)
170
- return outputDF
171
- def helperColTableToCorpus(self, dfList):
172
- #cl = cleanGenText()
173
- cl = cleanText.cleanGenText()
174
- stopWords = ["NaN","Unnamed:","nan"]
175
- outputDF = []
176
- # use the first length line as the column ref
177
- for pos in range(len(dfList[0])):
178
- outputLine = []
179
- for line in dfList:
180
- if pos < len(line):
181
- words = line[pos]
182
- words = str(words)
183
- else: words = ""
184
- if len(words) > 0:
185
- for word in words.split(" "):
186
- # remove specific stopwords for table: "NaN", "Unnamed: 0", row index: if the number appears first, it's just a row index; keep "KM1"
187
- if str(word) not in stopWords: # remove "NaN", "Unnamed:","nan"
188
- #word = cl.splitStickWords(word)
189
- word = cl.removePunct(word)
190
- word = " ".join(cl.removeStopWords(word))
191
- word = cl.removeTabWhiteSpaceNewLine(word)
192
- if len(word) > 1:
193
- if len(word.split(" ")) > 1:
194
- for x in word.split(" "):
195
- if len(x) > 1 and x.isnumeric()==False:
196
- outputLine.append(x.lower())
197
- else:
198
- if word.isnumeric() == False:
199
- outputLine.append(word.lower())
200
- if len(outputLine) > 0:
201
- outputDF.append(outputLine)
202
- return outputDF
203
- # create a corpus
204
- def createCorpusText(self, corpusText):
205
- '''ex: "Tom is cat. Jerry is mouse."
206
- corpus = [["Tom", "is", "cat"], ["Jerry", "is", "mouse"]]'''
207
- # the output should be like this:
208
- '''texts = {
209
- "Paragraph 1": [["Cat", "is", "an","animal], ["Tom", "is", "cat"]],
210
- "Paragraph 2": [["Mouse", "is", "an", "animal"], ["Jerry", "is", "mouse"]]
211
- }
212
- '''
213
- # separate paragraph
214
- '''Ex: Cat is an animal. Tom is cat.
215
-
216
- Mouse is an animal.
217
- Jerry is mouse.'''
218
- texts = {}
219
- cl = cleanText.cleanGenText()
220
- #cl = cleanGenText()
221
- for pos in range(len(corpusText.split("\n\n"))):
222
- if len(corpusText.split("\n\n")[pos]) > 0:
223
- texts["Paragraph "+str(pos)] = []
224
- lines = corpusText.split("\n\n")[pos]
225
- for line in lines.split("\n"):
226
- for l in line.split("."):
227
- if len(l) > 0:
228
- cl.removeTabWhiteSpaceNewLine(l)
229
- l = l.lower()
230
- newL = []
231
- for word in l.split(" "):
232
- if len(word) > 0:
233
- word = cl.removeStopWords(word)
234
- for w in word:
235
- if len(w) > 0 and w.isnumeric()==False:
236
- newL.append(w)
237
- if len(newL)>0:
238
- texts["Paragraph "+str(pos)].append(newL)
239
- if len(texts["Paragraph "+str(pos)]) == 0:
240
- del texts["Paragraph "+str(pos)]
241
- return texts
242
- def selectParaForWC(self,corpus):
243
- ''' corpus should be in the format:
244
- corpus = [["Tom", "is", "cat"], ["Jerry", "is", "mouse"]]'''
245
- corSize, window, vector_size, sample, negative, epochs, sg = None, None, None, None, None, None, None
246
- corSize = len(corpus)
247
- # less than 2000
248
- if 0 < corSize < 2000:
249
- window=3.5
250
- vector_size=75
251
- sample=1e-3
252
- negative=10
253
- epochs=10
254
- sg=1
255
- # 2000 - 100000
256
- elif 2000 <= corSize < 100000:
257
- window=3.5
258
- vector_size=75
259
- sample=1e-5
260
- negative=10
261
- epochs=10
262
- sg=1
263
- elif 100000 <=corSize < 1000000:
264
- window=7.5
265
- vector_size=150
266
- sample=1e-5
267
- negative=10
268
- epochs=6
269
- sg=0
270
- return window, vector_size, sample, negative, epochs, sg
271
- def trainWord2Vec(self,nameFile,modelName,saveFolder,window=3.5,
272
- vector_size=75,sample=1e-3,negative=10,epochs=10,sg=1):
273
- # if you dont have backup file, you can use again the nameFile just to increase the lenght of corpus
274
- jsonFile = ""
275
- jsonFile = openFile.openJsonFile(nameFile) # this is a corpus json file from an article
276
- cores = multiprocessing.cpu_count()
277
- combinedCorpus = []
278
- window, vector_size, sample, negative, epochs, sg = None, None, None, None, None, None
279
- if len(jsonFile) > 0:
280
- for key in jsonFile:
281
- combinedCorpus.extend(jsonFile[key])
282
- window, vector_size, sample, negative, epochs, sg = self.selectParaForWC(combinedCorpus)
283
- # # min_count=1 ensures all words are included
284
- '''w2vModel = Word2Vec(
285
- min_count=1,
286
- window=window,
287
- vector_size=vector_size,
288
- sample=sample,
289
- alpha=0.03,
290
- min_alpha=0.0007,
291
- negative=negative,
292
- workers=cores-1,
293
- epochs = epochs,
294
- sg=sg)'''
295
- #w2vModel = Word2Vec(vector_size=150, window=10, min_count=1, workers=4)
296
- accept = False
297
- while not accept:
298
- if window!=None and vector_size!=None and sample!=None and negative!=None and epochs!=None and sg!=None:
299
- try:
300
- w2vModel = Word2Vec(
301
- min_count=1,
302
- window=window,
303
- vector_size=vector_size,
304
- sample=sample,
305
- alpha=0.03,
306
- min_alpha=0.0007,
307
- negative=negative,
308
- workers=cores-1,
309
- epochs = epochs,
310
- sg=sg)
311
- w2vModel.build_vocab(combinedCorpus)
312
- w2vModel.train(combinedCorpus, total_examples=w2vModel.corpus_count, epochs=30)
313
- accept = True
314
- except:
315
- for key in jsonFile:
316
- combinedCorpus.extend(jsonFile[key])
317
- window, vector_size, sample, negative, epochs, sg = self.selectParaForWC(combinedCorpus)
318
- print("next is " + str(len(combinedCorpus)))
319
- else:
320
- print("no parameter to train")
321
- break
322
- #w2vModel.build_vocab(combinedCorpus)
323
- #w2vModel.train(combinedCorpus, total_examples=w2vModel.corpus_count, epochs=30)
324
- #w2vModel.save("/content/drive/MyDrive/CollectData/NER/word2Vec/TestExamples/models/wordVector_"+modelName+".model")
325
- #w2vModel.wv.save_word2vec_format("/content/drive/MyDrive/CollectData/NER/word2Vec/TestExamples/models/wordVector_"+modelName+".txt")
326
- w2vModel.save(saveFolder+"/"+modelName+".model")
327
- w2vModel.wv.save_word2vec_format(saveFolder+"/"+modelName+".txt")
328
- print("done w2v")
329
- else: print("no corpus to train")
330
- #return combinedCorpus
331
- def genSimilar(self,word,modelFile,n=10, cos_thres=0.7):
332
- # might not be a meaningful keyword
333
- #stopWords = ["show"]
334
- # same word but just plural nouns, tense
335
- simWords = [word+"s",word+"es",word+"ing",word+"ed"]
336
- model = KeyedVectors.load_word2vec_format(modelFile, binary = False) # model file in format txt
337
- results = model.most_similar(positive=[word],topn=n)
338
- #removeIndex = []
339
- #currN = copy.deepcopy(n)
340
- '''for r in range(len(results)):
341
- if len(results[r][0]) < 2:
342
- removeIndex.append(results[r])
343
- # remove the same word but just plural and singular noun and lower than the cos_thres
344
- elif results[r][0] == word:
345
- removeIndex.append(results[r])
346
- elif results[r][0] in simWords or float(results[r][1]) < cos_thres or results[r][0] in stopWords:
347
- removeIndex.append(results[r])
348
- for rem in removeIndex:
349
- results.remove(rem)
350
- while len(results)!=n and len(results) != 0:
351
- moreNewResult = model.most_similar(positive=[word],topn=currN+1)[-1]
352
- if moreNewResult not in results and len(moreNewResult[0])>1:
353
- if moreNewResult[0] not in stopWords and results[0] != word:
354
- results.append(moreNewResult)
355
- currN +=1'''
356
- return results
357
- # adding our model into spacy
358
- # this deals with command line; but instead of using it, we write python script to run command line
359
- def loadWordVec(self,modelName,wordVec):
360
- # modelName is the name you want to save into spacy
361
- # wordVec is the trained word2vec in txt format
362
- subprocess.run([sys.executable,
363
- "-m",
364
- "spacy",
365
- "init-model",
366
- "en",
367
- modelName, # this modelName comes from the saved modelName of function trainWord2Vec
368
- "--vectors-loc",
369
- wordVec])
1
+ '''WORD TO VECTOR'''
2
+ import pandas as pd
3
+ import json
4
+ import gensim
5
+ import spacy
6
+ from DefaultPackages import openFile, saveFile
7
+ from NER import cleanText
8
+ from gensim.models.keyedvectors import KeyedVectors
9
+ from gensim.test.utils import common_texts
10
+ from gensim.models.word2vec import Word2Vec
11
+ from gensim.scripts.glove2word2vec import glove2word2vec
12
+ from gensim.test.utils import datapath, get_tmpfile
13
+ from gensim.models import Phrases
14
+ from gensim.models.phrases import Phraser
15
+ import sys
16
+ import subprocess
17
+ import os
18
+ # can try multiprocessing to run quicker
19
+ import multiprocessing
20
+ import copy
21
+ sys.setrecursionlimit(1000)
22
+ # create folder word2Vec
23
+ #! mkdir /content/drive/MyDrive/CollectData/NER/word2Vec
24
+ # create word2vec model
25
+ #model = KeyedVectors.load_word2vec_format('/content/drive/MyDrive/CollectData/NER/word2Vec', binary=True)
26
+ '''Some notes for this model
27
+ sometimes when we do the corpus, there are some adverbs which are unnecessary but might be seen as
28
+ a similar word to the word we are finding, so can we try to preprocess text so that
29
+ we make the corpus more effective and only contains the important words. Then when we
30
+ train the model, the important words will be seen as important. Or
31
+ when we already have the similar list of words, we can remove the words in there
32
+ that are stopwords/unnecessary words.'''
33
+ ### For more complex analysis, consider using sentence embedding models like "Doc2Vec" to represent the meaning of entire sentences instead of just individual words
34
+ class word2Vec():
35
+ def __init__(self, nameFile=None, modelName=None):
36
+ self.nameFile = nameFile
37
+ self.modelName = modelName
38
+ #self.nlp = spacy.load("en_core_web_lg")
39
+ self.cl = cleanText.cleanGenText()
40
+ def spacy_similarity(self, word):
41
+ # when use word2vec, try medium or large is better
42
+ # maybe try odc similarity?
43
+ doc = self.nlp(word)
44
+ for token1 in doc:
45
+ for token2 in doc:
46
+ print(token1.text, token2.text, token1.similarity(token2))
47
+ pass
48
+ # clean text before transform to corpus
49
+ def cleanTextBeforeCorpus(self,oriText, doi=None):
50
+ #cl = cleanText.cleanGenText()
51
+ #cl = cleanGenText()
52
+ output = ""
53
+ alreadyRemoveDoi = False
54
+ for word in oriText.split(" "):
55
+ # remove DOI
56
+ if doi != None and doi in oriText:
57
+ if alreadyRemoveDoi == False:
58
+ newWord = self.cl.removeDOI(word,doi)
59
+ if len(newWord) > 0 and newWord != word:
60
+ alreadyRemoveDoi = True
61
+ word = newWord
62
+ # remove punctuation
63
+ # split the sticked words
64
+ #word = cl.splitStickWords(word)
65
+ # remove punctuation
66
+ word = self.cl.removePunct(word,True)
67
+ # remove URL
68
+ word = self.cl.removeURL(word)
69
+ # remove HTMLTag
70
+ word = self.cl.removeHTMLTag(word)
71
+ # remove tab, white space, newline
72
+ word = self.cl.removeTabWhiteSpaceNewLine(word)
73
+ # optional: remove stopwords
74
+ #word = cl.removeStopWords(word)
75
+ if len(word)>0:
76
+ output += word + " "
77
+ return output
78
+ def cleanAllTextBeforeCorpus(self, allText, doi=None):
79
+ cleanOutput = ""
80
+ remove = "Evaluation Warning: The document was created with Spire.Doc for Python."
81
+ if len(allText) > 0:
82
+ corpusText = allText.split("\n\n")
83
+ for pos in range(len(corpusText)):
84
+ lines = corpusText[pos]
85
+ if len(lines) > 0:
86
+ for line in lines.split("\n"):
87
+ if remove in line: line = line.replace(remove, "")
88
+ clean_text = self.cleanTextBeforeCorpus(line, doi)
89
+ cleanOutput += clean_text + "\n"
90
+ cleanOutput += "\n\n"
91
+ return cleanOutput
92
+ import urllib.parse, requests
93
+
94
+ def tableTransformToCorpusText(self, df, excelFile=None):
95
+ # PDF, Excel, WordDoc
96
+ #cl = cleanText.cleanGenText()
97
+ corpus = {}
98
+ # PDF or df
99
+ if excelFile == None:
100
+ if len(df) > 0:
101
+ try:
102
+ for i in range(len(df)):
103
+ # each new dimension/page is considered to be a sentence which ends with the period.
104
+ # each new line is a new list, and each new df is a new corpus
105
+ outputDF = []
106
+ text = df[i].values.tolist()
107
+ if len(text) > 0:
108
+ outputRowDF = self.helperRowTableToCorpus(text)
109
+ #outputColDF = self.helperColTableToCorpus(text)
110
+ outputDF.extend(outputRowDF)
111
+ #outputDF.extend(outputColDF)
112
+ if len(outputDF) > 0:
113
+ corpus["corpus" + str(i)] = outputDF
114
+ except:
115
+ outputDF = []
116
+ text = df.values.tolist()
117
+ if len(text) > 0:
118
+ outputRowDF = self.helperRowTableToCorpus(text)
119
+ #outputColDF = self.helperColTableToCorpus(text)
120
+ outputDF.extend(outputRowDF)
121
+ #outputDF.extend(outputColDF)
122
+ if len(outputDF) > 0:
123
+ corpus["corpus0"] = outputDF
124
+ else:
125
+ try:
126
+ df = pd.ExcelFile(excelFile)
127
+ except:
128
+ if excelFile.endswith('.xls'):
129
+ df = pd.read_excel(excelFile, engine='xlrd')
130
+ else:
131
+ df = pd.read_excel(excelFile, engine='openpyxl')
132
+ sheetNames = df.sheet_names
133
+ output = []
134
+ if len(sheetNames) > 0:
135
+ for s in range(len(sheetNames)):
136
+ outputDF = []
137
+ with pd.ExcelFile(excelFile) as xls:
138
+ data = pd.read_excel(xls, sheetNames[s])
139
+ if sheetNames[s] != 'Evaluation Warning':
140
+ text = data.values.tolist()
141
+ if len(text) > 0:
142
+ outputRowDF = self.helperRowTableToCorpus(text)
143
+ #outputColDF = self.helperColTableToCorpus(text)
144
+ outputDF.extend(outputRowDF)
145
+ #outputDF.extend(outputColDF)
146
+ if len(outputDF) > 0:
147
+ corpus["corpus" + str(s)] = outputDF
148
+ return corpus
149
+ def helperRowTableToCorpus(self, textList):
150
+ #cl = cleanGenText()
151
+ #cl = cleanText.cleanGenText()
152
+ stopWords = ["NaN","Unnamed:","nan"]
153
+ outputDF = []
154
+ for line in textList:
155
+ outputLine = []
156
+ for words in line:
157
+ words = str(words)
158
+ if len(words) > 0:
159
+ for word in words.split(" "):
160
+ # remove specific stopwords for table: "NaN", "Unnamed: 0", row index: if the number appears first, it's just a row index; keep "KM1"
161
+ if str(word) not in stopWords: # remove "NaN", "Unnamed:","nan"
162
+ #word = cl.splitStickWords(word)
163
+ word = self.cl.removePunct(word)
164
+ word = " ".join(self.cl.removeStopWords(word))
165
+ word = self.cl.removeTabWhiteSpaceNewLine(word)
166
+ if len(word) > 1:
167
+ if len(word.split(" ")) > 1:
168
+ for x in word.split(" "):
169
+ if len(x) > 1 and x.isnumeric()==False:
170
+ outputLine.append(x.lower())
171
+ else:
172
+ if word.isnumeric() == False:
173
+ outputLine.append(word.lower())
174
+ if len(outputLine) > 0:
175
+ outputDF.append(outputLine)
176
+ return outputDF
177
+ def helperColTableToCorpus(self, dfList):
178
+ #cl = cleanGenText()
179
+ #cl = cleanText.cleanGenText()
180
+ stopWords = ["NaN","Unnamed:","nan"]
181
+ outputDF = []
182
+ # use the first length line as the column ref
183
+ for pos in range(len(dfList[0])):
184
+ outputLine = []
185
+ for line in dfList:
186
+ if pos < len(line):
187
+ words = line[pos]
188
+ words = str(words)
189
+ else: words = ""
190
+ if len(words) > 0:
191
+ for word in words.split(" "):
192
+ # remove specific stopwords for table: "NaN", "Unnamed: 0", row index: if the number appears first, it's just a row index; keep "KM1"
193
+ if str(word) not in stopWords: # remove "NaN", "Unnamed:","nan"
194
+ #word = cl.splitStickWords(word)
195
+ word = self.cl.removePunct(word)
196
+ word = " ".join(self.cl.removeStopWords(word))
197
+ word = self.cl.removeTabWhiteSpaceNewLine(word)
198
+ if len(word) > 1:
199
+ if len(word.split(" ")) > 1:
200
+ for x in word.split(" "):
201
+ if len(x) > 1 and x.isnumeric()==False:
202
+ outputLine.append(x.lower())
203
+ else:
204
+ if word.isnumeric() == False:
205
+ outputLine.append(word.lower())
206
+ if len(outputLine) > 0:
207
+ outputDF.append(outputLine)
208
+ return outputDF
209
+ # create a corpus
210
+ def createCorpusText(self, corpusText):
211
+ '''ex: "Tom is cat. Jerry is mouse."
212
+ corpus = [["Tom", "is", "cat"], ["Jerry", "is", "mouse"]]'''
213
+ # the output should be like this:
214
+ '''texts = {
215
+ "Paragraph 1": [["Cat", "is", "an","animal], ["Tom", "is", "cat"]],
216
+ "Paragraph 2": [["Mouse", "is", "an", "animal"], ["Jerry", "is", "mouse"]]
217
+ }
218
+ '''
219
+ # separate paragraph
220
+ '''Ex: Cat is an animal. Tom is cat.
221
+
222
+ Mouse is an animal.
223
+ Jerry is mouse.'''
224
+ texts = {}
225
+ #cl = cleanText.cleanGenText()
226
+ #cl = cleanGenText()
227
+ corpus = corpusText.split("\n\n")
228
+ for pos in range(len(corpus)):
229
+ if len(corpus[pos]) > 0:
230
+ texts["Paragraph "+str(pos)] = []
231
+ lines = corpus[pos]
232
+ for line in lines.split("\n"):
233
+ for l in line.split("."):
234
+ if len(l) > 0:
235
+ l = self.cl.removeTabWhiteSpaceNewLine(l)
236
+ l = l.lower()
237
+ newL = []
238
+ for word in l.split(" "):
239
+ if len(word) > 0:
240
+ word = self.cl.removeStopWords(word)
241
+ for w in word:
242
+ if len(w) > 0 and w.isnumeric()==False:
243
+ newL.append(w)
244
+ if len(newL)>0:
245
+ texts["Paragraph "+str(pos)].append(newL)
246
+ if len(texts["Paragraph "+str(pos)]) == 0:
247
+ del texts["Paragraph "+str(pos)]
248
+ return texts
249
+
250
+ def selectParaForWC(self, corpus):
251
+ """
252
+ corpus = [["Tom", "is", "cat"], ["Jerry", "is", "mouse"]]
253
+ Heuristically determine Word2Vec parameters.
254
+ """
255
+ corSize = len(corpus)
256
+
257
+ if corSize == 0:
258
+ return None, None, None, None, None, None
259
+
260
+ # Adjust parameters based on corpus size
261
+ if corSize < 2000:
262
+ # Small corpus — need high generalization
263
+ window = 3
264
+ vector_size = 100
265
+ sample = 1e-3
266
+ negative = 5
267
+ epochs = 20
268
+ sg = 1 # Skip-gram preferred for rare words
269
+ elif corSize < 10000:
270
+ window = 5
271
+ vector_size = 150
272
+ sample = 1e-4
273
+ negative = 10
274
+ epochs = 20
275
+ sg = 1
276
+ elif corSize < 100000:
277
+ window = 7
278
+ vector_size = 200
279
+ sample = 1e-5
280
+ negative = 15
281
+ epochs = 15
282
+ sg = 1
283
+ elif corSize < 500000:
284
+ window = 10
285
+ vector_size = 250
286
+ sample = 1e-5
287
+ negative = 15
288
+ epochs = 10
289
+ sg = 0 # CBOW is okay when data is large
290
+ else:
291
+ # Very large corpus
292
+ window = 12
293
+ vector_size = 300
294
+ sample = 1e-6
295
+ negative = 20
296
+ epochs = 5
297
+ sg = 0
298
+
299
+ return window, vector_size, sample, negative, epochs, sg
300
+
301
+
302
+ def trainWord2Vec(self,nameFile,modelName,saveFolder,window=None,
303
+ vector_size=None,sample=None,negative=None,epochs=None,sg=None):
304
+ jsonFile = ""
305
+ jsonFile = openFile.openJsonFile(nameFile) # this is a corpus json file from an article
306
+ if not jsonFile:
307
+ print("No corpus to train")
308
+ return
309
+ cores = multiprocessing.cpu_count()
310
+ combinedCorpus = []
311
+ for key in jsonFile:
312
+ combinedCorpus.extend(jsonFile[key])
313
+ # detect phrase before choosing parameters
314
+ phrases = Phrases(combinedCorpus, min_count=2, threshold=10)
315
+ bigram = Phraser(phrases)
316
+ combinedCorpus = [bigram[sent] for sent in combinedCorpus]
317
+
318
+ if window==None and vector_size==None and sample==None and negative==None and epochs==None and sg==None:
319
+ window, vector_size, sample, negative, epochs, sg = self.selectParaForWC(combinedCorpus)
320
+ # # min_count=1 ensures all words are included
321
+ #w2vModel = Word2Vec(vector_size=150, window=10, min_count=1, workers=4)
322
+ accept = False
323
+ # add retry limit because if training keeps failing (bad corpus or corrupted input), it’ll keep retrying without limit.
324
+ retries = 0
325
+ while not accept and retries < 3:
326
+ if window!=None and vector_size!=None and sample!=None and negative!=None and epochs!=None and sg!=None:
327
+ try:
328
+ w2vModel = Word2Vec(
329
+ min_count=1,
330
+ window=window,
331
+ vector_size=vector_size,
332
+ sample=sample,
333
+ alpha=0.03,
334
+ min_alpha=0.0007,
335
+ negative=negative,
336
+ workers=max(1, cores-1), # avoid workers=0 on single-core machines
337
+ epochs = epochs,
338
+ sg=sg)
339
+ w2vModel.build_vocab(combinedCorpus)
340
+ w2vModel.train(combinedCorpus, total_examples=w2vModel.corpus_count, epochs=epochs)
341
+ accept = True
342
+ except Exception as e:
343
+ print(f"Retry #{retries+1} failed: {e}")
344
+ retries +=1
345
+ else:
346
+ print("no parameter to train")
347
+ break
348
+ #w2vModel.build_vocab(combinedCorpus)
349
+ #w2vModel.train(combinedCorpus, total_examples=w2vModel.corpus_count, epochs=30)
350
+ #w2vModel.save("/content/drive/MyDrive/CollectData/NER/word2Vec/TestExamples/models/wordVector_"+modelName+".model")
351
+ #w2vModel.wv.save_word2vec_format("/content/drive/MyDrive/CollectData/NER/word2Vec/TestExamples/models/wordVector_"+modelName+".txt")
352
+ if not accept:
+ print("Word2Vec training did not succeed (or no parameters were given); nothing to save")
+ return
+ w2vModel.save(saveFolder+"/"+modelName+".model")
353
+ w2vModel.wv.save_word2vec_format(saveFolder+"/"+modelName+".txt")
354
+ print("done w2v")
355
+ #return combinedCorpus
356
+ def updateWord2Vec(self, modelPath, newCorpus, saveFolder=None):
357
+ if not newCorpus:
358
+ raise ValueError("New corpus is empty!")
359
+
360
+ model = Word2Vec.load(modelPath)
361
+
362
+ # Phrase detection on new data
363
+ phrases = Phrases(newCorpus, min_count=2, threshold=10)
364
+ bigram = Phraser(phrases)
365
+ newCorpus = [bigram[sent] for sent in newCorpus]
366
+
367
+ # Update vocab & retrain
368
+ model.build_vocab(newCorpus, update=True)
369
+ model.train(newCorpus, total_examples=len(newCorpus), epochs=model.epochs)
370
+
371
+ def genSimilar(self,word,modelFile,n=10, cos_thres=0.7):
372
+ # might not be a meaningful keyword
373
+ #stopWords = ["show"]
374
+ # same word but just plural nouns, tense
375
+ simWords = [word+"s",word+"es",word+"ing",word+"ed"]
376
+ model = KeyedVectors.load_word2vec_format(modelFile, binary = False) # model file in format txt
377
+ results = model.most_similar(positive=[word],topn=n)
378
+ #removeIndex = []
379
+ #currN = copy.deepcopy(n)
380
+ '''for r in range(len(results)):
381
+ if len(results[r][0]) < 2:
382
+ removeIndex.append(results[r])
383
+ # remove the same word but just plural and singular noun and lower than the cos_thres
384
+ elif results[r][0] == word:
385
+ removeIndex.append(results[r])
386
+ elif results[r][0] in simWords or float(results[r][1]) < cos_thres or results[r][0] in stopWords:
387
+ removeIndex.append(results[r])
388
+ for rem in removeIndex:
389
+ results.remove(rem)
390
+ while len(results)!=n and len(results) != 0:
391
+ moreNewResult = model.most_similar(positive=[word],topn=currN+1)[-1]
392
+ if moreNewResult not in results and len(moreNewResult[0])>1:
393
+ if moreNewResult[0] not in stopWords and results[0] != word:
394
+ results.append(moreNewResult)
395
+ currN +=1'''
396
+ return results
397
+ # add more data to existing word2vec model
398
+ def updateWord2Vec(self, modelPath, newCorpus, saveFolder=None):
399
+ if not newCorpus:
400
+ raise ValueError("New corpus is empty!")
401
+
402
+ model = Word2Vec.load(modelPath)
403
+
404
+ # Phrase detection on new data
405
+ phrases = Phrases(newCorpus, min_count=2, threshold=10)
406
+ bigram = Phraser(phrases)
407
+ newCorpus = [bigram[sent] for sent in newCorpus]
408
+
409
+ # Update vocab & retrain
410
+ model.build_vocab(newCorpus, update=True)
411
+ model.train(newCorpus, total_examples=len(newCorpus), epochs=model.epochs)
412
+
413
+ # Save updated model
414
+ if saveFolder:
415
+ os.makedirs(saveFolder, exist_ok=True)
416
+ name = os.path.basename(modelPath).replace(".model", "_updated.model")
417
+ model.save(f"{saveFolder}/{name}")
418
+ print(f"🔁 Model updated and saved to {saveFolder}/{name}")
419
+ else:
420
+ model.save(modelPath)
421
+ print(f"🔁 Model updated and overwritten at {modelPath}")
422
+
423
+ # adding our model into spacy
424
+ # this deals with command line; but instead of using it, we write python script to run command line
425
+ def loadWordVec(self,modelName,wordVec):
426
+ # modelName is the name you want to save into spacy
427
+ # wordVec is the trained word2vec in txt format
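+ # equivalent CLI call (spaCy v2-style command): python -m spacy init-model en <modelName> --vectors-loc <wordVec>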
428
+ subprocess.run([sys.executable,
429
+ "-m",
430
+ "spacy",
431
+ "init-model",
432
+ "en",
433
+ modelName, # this modelName comes from the saved modelName of function trainWord2Vec
434
+ "--vectors-loc",
435
+ wordVec])
436
+ print("done")
README.md CHANGED
@@ -1,74 +1,80 @@
1
  ---
2
- setup: bash setup.sh
3
- title: MtDNALocation
4
- emoji: 📊
5
  colorFrom: blue
6
- colorTo: purple
7
  sdk: gradio
8
- sdk_version: 5.25.0
9
  app_file: app.py
10
  pinned: false
11
- license: mit
12
- short_description: mtDNA Location Classification tool
13
  ---
14
 
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
16
-
17
- # Installation
18
- ## Set up environments and start GUI:
19
- ```bash
20
- git clone https://github.com/Open-Access-Bio-Data/mtDNA-Location-Classifier.git
21
- ```
22
- If installed using mamba (recommended):
23
- ```bash
24
- mamba env create -f env.yaml
25
- ```
26
- If not, check current python version in terminal and make sure that it is python version 3.10, then run
27
- ```bash
28
- pip install -r requirements.txt
29
- ```
30
- To start the programme, run this in terminal:
31
- ```bash
32
- python app.py
33
- ```
34
- Then follow its instructions
35
- # Descriptions:
36
- mtDNA-Location-Classifier uses [Gradio](https://www.gradio.app/docs) to handle the front-end interactions.
37
-
38
- The programme takes **an accession number** (an NCBI GenBank/nuccore identifier) as input and returns the likely origin of the sequence through `classify_sample_location_cached(accession=accession_number)`. This function wraps around a pipeline that proceeds as follow:
39
- ## Steps 1-3: Check and retrieve base materials: the Pubmed ID, isolate, DOI and text:
40
- - Which are respectively:
41
-
42
- ### Step 1: pubmed_ids and isolates
43
- `get_info_from accession(accession=accession_number)`
44
- - Current input is a string of `accession_number` and output are two lists, one of PUBMED IDs and one of isolate(s).
45
- - Which look through the metadata of the sequence with `accession_number` and extract `PUBMED ID` if available or `isolate` information.
46
- - The presence of PUBMED ID is currently important for the retrieval of texts in the next steps, which are eventually used by method 4.1 (question-answering) and 4.2 (infer from haplogroup)
47
- - Some sequences might not have `isolate` info but its availibity is optional. (as they might be used by method 4.1 and 4.2 as alternative)
48
-
49
- ### Step 2: dois
50
- `get_doi_from_pubmed_id(pubmed_ids = pubmed_ids)`
51
- - Input is a list of PUBMED IDs of the sequence with `accession_number` (retrieved from previous step) and output is a dictionary with keys = PUBMED IDs and values = according DOIs.
52
- - The pubmed_ids are retrieved from the `get_info_from accession(accession=accession_number)` mentioned above.
53
- - The DOIs will be passed down to dependent functions to extract texts of publications to pass on to method 4.1 and 4.2
54
-
55
- ### Step 3: get text
56
- `get_paper_text(dois = dois)`
57
- - Input is currently a list of dois retrieved from previous step and output is a dictionary with keys = sources (doi links or file type) (We might improve this to have other inputs in addition to just doi links - maybe files); values = texts obtained from sources.
58
- - Output of this step is crucial to method 4.1 and 4.2
59
-
60
-
61
- ## Step 4: Prediction of origin:
62
- ### Method 4.0:
63
- - The first method attempts to directly look in the metadata for information that was submitted along with the sequence. Thus, it does not require availability of PUBMED IDs/DOIs or isolates.
64
- - However, this information is not always available in the submission. Thus, we use other methods (4.1,4.2) to retrieve publications through which we can extract the information of the source of mtDNA
65
-
66
- ### Method 4.1:
67
- -
68
-
69
- ### Method 4.2:
70
- -
71
-
72
- ## More in the package
73
- ### extraction of text from HTML
74
- ### extraction of text from PDF
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: mtDNA Location Classifier 🧬
3
+ emoji: 🧬
 
4
  colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 5.38.2
8
  app_file: app.py
9
  pinned: false
 
 
10
  ---
11
 
12
+ # mtDNA Location Classifier
13
+
14
+ ## Publish faster. Analyze smarter.
15
+
16
+ Are you dealing with **incomplete mtDNA metadata** (like country, ethnicity, sample type)?
17
+ This tool helps researchers like you generate **clean, structured labels** — ready to use for your paper.
18
+
19
+ ---
20
+
21
+ ## What You’ll Get:
22
+
23
+ - Inference from sequence ID alone
24
+ - Handles hard edge cases (e.g., no clear PubMed ID, Direct Submission, etc.)
25
+ - Clear sample type, country, and more (ethnicity, phenotype, etc.)
26
+ - Excel export with citations
27
+ - Feedback-based refund policy
28
+
29
+ ---
30
+
31
+ ## Free Tier
32
+
33
+ - 30 free samples, no email needed
34
+ - +20 bonus samples + Excel file when you enter your email
35
+ - Don’t like the result? Tell us why, and we won’t count the bad ones (email required)
36
+
37
+ ---
38
+
39
+ ## Pricing: Pay As You Go (DIY)
40
+
41
+ | Case Type | Price/Sample | Output |
42
+ |-------------|--------------|----------------------------------------|
43
+ | Normal | $0.10 | Sample Type + Country |
44
+ | Edge | $1.00 | Sample Type + Country |
45
+ | Niche | $2.00 | Sample Type + Country + 1 Custom Label |
46
+
47
+ ---
48
+
49
+ ## Batch Discount (1000+ Samples)
50
+
51
+ - **Normal Output** → $100 total ($0.10/sample)
52
+ Unsatisfied samples? We’ll refund them.
53
+
54
+ - **Niche Output** → $500 total ($0.50/sample)
55
+ Includes an extra label like ethnicity or phenotype.
56
+
57
+ ---
58
+
59
+ ## Early User Bonus (Limited!)
60
+
61
+ Are you one of our **first 10 paying users**?
62
+ Just type `early_user` in your email.
63
+
64
+ You'll get **20% lifetime discount** on every plan — forever.
65
+ We’ll apply this automatically so you don’t have to calculate anything.
66
+
67
+ ## Our Mission in Building This
68
+
69
+ Provide clean, high-quality, open-access biological datasets that save researchers time and improve scientific reproducibility.
70
+ Build the world's cleanest, AI-driven, open metadata source for biological research.
71
+
72
+ ---
73
+
74
+ ## Try It Now
75
+
76
+ Paste your sequence ID on our demo:
77
+ [Try the Classifier](https://huggingface.co/spaces/VyLala/mtDNALocation)
78
+
79
+ Need help or bulk analysis?
80
README_OLD_VERSION.md ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ setup: bash setup.sh
3
+ title: MtDNALocation
4
+ emoji: 📊
5
+ colorFrom: blue
6
+ colorTo: purple
7
+ sdk: gradio
8
+ sdk_version: 5.25.0
9
+ app_file: app.py
10
+ pinned: false
11
+ license: mit
12
+ short_description: mtDNA Location Classification tool
13
+ ---
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
16
+
17
+ # Installation
18
+ ## Set up environments and start GUI:
19
+ ```bash
20
+ git clone https://github.com/Open-Access-Bio-Data/mtDNA-Location-Classifier.git
21
+ ```
22
+ If installed using mamba (recommended):
23
+ ```bash
24
+ mamba env create -f env.yaml
25
+ ```
26
+ If not, check the current Python version in your terminal and make sure it is Python 3.10, then run
27
+ ```bash
28
+ pip install -r requirements.txt
29
+ ```
30
+ To start the programme, run this in terminal:
31
+ ```bash
32
+ python app.py
33
+ ```
34
+ Then follow its instructions
35
+ # Descriptions:
36
+ mtDNA-Location-Classifier uses [Gradio](https://www.gradio.app/docs) to handle the front-end interactions.
37
+
38
+ The programme takes **an accession number** (an NCBI GenBank/nuccore identifier) as input and returns the likely origin of the sequence through `classify_sample_location_cached(accession=accession_number)`. This function wraps around a pipeline that proceeds as follows:
39
+ ## Steps 1-3: Check and retrieve base materials: the Pubmed ID, isolate, DOI and text:
40
+ - These are, respectively:
41
+
42
+ ### Step 1: pubmed_ids and isolates
43
+ `get_info_from_accession(accession=accession_number)`
44
+ - The current input is a string `accession_number`, and the output is two lists: one of PUBMED IDs and one of isolate(s).
45
+ - This step looks through the metadata of the sequence with `accession_number` and extracts the `PUBMED ID` (if available) and `isolate` information.
46
+ - The presence of a PUBMED ID is currently important for the retrieval of texts in the next steps, which are eventually used by method 4.1 (question-answering) and method 4.2 (infer from haplogroup).
47
+ - Some sequences might not have `isolate` info, but its availability is optional (it may be used by methods 4.1 and 4.2 as an alternative).
48
+
49
+ ### Step 2: dois
50
+ `get_doi_from_pubmed_id(pubmed_ids = pubmed_ids)`
51
+ - Input is a list of PUBMED IDs for the sequence with `accession_number` (retrieved from the previous step), and output is a dictionary with keys = PUBMED IDs and values = the corresponding DOIs.
52
+ - The pubmed_ids are retrieved from the `get_info_from_accession(accession=accession_number)` call mentioned above.
53
+ - The DOIs are passed down to dependent functions to extract the texts of publications, which are then passed on to methods 4.1 and 4.2.
54
+
55
+ ### Step 3: get text
56
+ `get_paper_text(dois = dois)`
57
+ - Input is currently a list of DOIs retrieved from the previous step, and output is a dictionary with keys = sources (DOI links or file type) and values = texts obtained from those sources. (We might improve this to accept other inputs besides DOI links, such as files.)
58
+ - The output of this step is crucial to methods 4.1 and 4.2 (a minimal sketch of how steps 1-3 chain together is shown below).
59
+
60
+
61
+ ## Step 4: Prediction of origin:
62
+ ### Method 4.0:
63
+ - The first method attempts to directly look in the metadata for information that was submitted along with the sequence. Thus, it does not require availability of PUBMED IDs/DOIs or isolates.
64
+ - However, this information is not always available in the submission. Thus, we use other methods (4.1,4.2) to retrieve publications through which we can extract the information of the source of mtDNA
65
+
66
+ ### Method 4.1:
67
+ -
68
+
69
+ ### Method 4.2:
70
+ -
71
+
72
+ ## More in the package
73
+ ### extraction of text from HTML
74
+ ### extraction of text from PDF
app.py CHANGED
@@ -1,697 +1,793 @@
1
- import gradio as gr
2
- import mtdna_backend
3
- import json
4
- import data_preprocess, model, pipeline
5
- import os
6
- import hashlib
7
- import threading
8
- # Gradio UI
9
- #stop_flag = gr.State(value=False)
10
- class StopFlag:
11
- def __init__(self):
12
- self.value = False
13
- global_stop_flag = StopFlag() # Shared between run + stop
14
-
15
- with gr.Blocks() as interface:
16
- gr.Markdown("# 🧬 mtDNA Location Classifier (MVP)")
17
-
18
- #inputMode = gr.Radio(choices=["Single Accession", "Batch Input"], value="Single Accession", label="Choose Input Mode")
19
- user_email = gr.Textbox(label="📧 Your email (used to track free quota)")
20
- usage_display = gr.Markdown("", visible=False)
21
-
22
- # with gr.Group() as single_input_group:
23
- # single_accession = gr.Textbox(label="Enter Single Accession (e.g., KU131308)")
24
-
25
- # with gr.Group(visible=False) as batch_input_group:
26
- # raw_text = gr.Textbox(label="🧬 Paste Accession Numbers (e.g., MF362736.1,MF362738.1,KU131308,MW291678)")
27
- # resume_file = gr.File(label="🗃️ Previously saved Excel output (optional)", file_types=[".xlsx"], interactive=True)
28
- # gr.HTML("""<a href="https://drive.google.com/file/d/1t-TFeIsGVu5Jh3CUZS-VE9jQWzNFCs_c/view?usp=sharing" download target="_blank">Download Example CSV Format</a>""")
29
- # gr.HTML("""<a href="https://docs.google.com/spreadsheets/d/1lKqPp17EfHsshJGZRWEpcNOZlGo3F5qU/edit?usp=sharing&ouid=112390323314156876153&rtpof=true&sd=true" download target="_blank">Download Example Excel Format</a>""")
30
- # file_upload = gr.File(label="📁 Or Upload CSV/Excel File", file_types=[".csv", ".xlsx"], interactive=True, elem_id="file-upload-box")
31
- raw_text = gr.Textbox(label="🧚 Input Accession Number(s) (single (KU131308) or comma-separated (e.g., MF362736.1,MF362738.1,KU131308,MW291678))")
32
- #resume_file = gr.File(label="🗃️ Previously saved Excel output (optional)", file_types=[".xlsx"], interactive=True)
33
- gr.HTML("""<a href="https://docs.google.com/spreadsheets/d/1lKqPp17EfHsshJGZRWEpcNOZlGo3F5qU/edit?usp=sharing" download target="_blank">Download Example Excel Format</a>""")
34
- file_upload = gr.File(label="📁 Or Upload CSV/Excel File", file_types=[".csv", ".xlsx"], interactive=True)
35
-
36
- with gr.Row():
37
- run_button = gr.Button("🔍 Submit and Classify")
38
- stop_button = gr.Button(" Stop Batch", visible=True)
39
- reset_button = gr.Button("🔄 Reset")
40
-
41
- status = gr.Markdown(visible=False)
42
-
43
- with gr.Group(visible=False) as results_group:
44
- # with gr.Accordion("Open to See the Result", open=False) as results:
45
- # with gr.Row():
46
- # output_summary = gr.Markdown(elem_id="output-summary")
47
- # output_flag = gr.Markdown(elem_id="output-flag")
48
-
49
- # gr.Markdown("---")
50
-
51
- with gr.Accordion("Open to See the Output Table", open=False) as table_accordion:
52
- output_table = gr.HTML(render=True)
53
-
54
- #with gr.Row():
55
- #output_type = gr.Dropdown(choices=["Excel", "JSON", "TXT"], label="Select Output Format", value="Excel")
56
- #download_button = gr.Button("⬇️ Download Output")
57
- #download_file = gr.File(label="Download File Here",visible=False)
58
- download_file = gr.File(label="Download File Here", visible=False, interactive=True)
59
- progress_box = gr.Textbox(label="Live Processing Log", lines=20, interactive=False)
60
-
61
- gr.Markdown("---")
62
-
63
- gr.Markdown("### 💬 Feedback (required)")
64
- q1 = gr.Textbox(label="1️⃣ Was the inferred location accurate or helpful? Please explain.")
65
- q2 = gr.Textbox(label="2️⃣ What would improve your experience with this tool?")
66
- contact = gr.Textbox(label="📧 Your email or institution (optional)")
67
- submit_feedback = gr.Button("✅ Submit Feedback")
68
- feedback_status = gr.Markdown()
69
-
70
- # Functions
71
- # def toggle_input_mode(mode):
72
- # if mode == "Single Accession":
73
- # return gr.update(visible=True), gr.update(visible=False)
74
- # else:
75
- # return gr.update(visible=False), gr.update(visible=True)
76
-
77
- def classify_with_loading():
78
- return gr.update(value=" Please wait... processing...",visible=True) # Show processing message
79
-
80
- # def classify_dynamic(single_accession, file, text, resume, email, mode):
81
- # if mode == "Single Accession":
82
- # return classify_main(single_accession) + (gr.update(visible=False),)
83
- # else:
84
- # #return summarize_batch(file, text) + (gr.update(visible=False),) # Hide processing message
85
- # return classify_mulAcc(file, text, resume) + (gr.update(visible=False),) # Hide processing message
86
- # Logging helpers defined early to avoid NameError
87
-
88
-
89
- # def classify_dynamic(single_accession, file, text, resume, email, mode):
90
- # if mode == "Single Accession":
91
- # return classify_main(single_accession) + (gr.update(value="", visible=False),)
92
- # else:
93
- # return classify_mulAcc(file, text, resume, email, log_callback=real_time_logger, log_collector=log_collector)
94
-
95
- # for single accession
96
- # def classify_main(accession):
97
- # #table, summary, labelAncient_Modern, explain_label = mtdna_backend.summarize_results(accession)
98
- # table = mtdna_backend.summarize_results(accession)
99
- # #flag_output = f"### 🏺 Ancient/Modern Flag\n**{labelAncient_Modern}**\n\n_Explanation:_ {explain_label}"
100
- # return (
101
- # #table,
102
- # make_html_table(table),
103
- # # summary,
104
- # # flag_output,
105
- # gr.update(visible=True),
106
- # gr.update(visible=False),
107
- # gr.update(visible=False)
108
- # )
109
-
110
- #stop_flag = gr.State(value=False)
111
- #stop_flag = StopFlag()
112
-
113
- # def stop_batch(stop_flag):
114
- # stop_flag.value = True
115
- # return gr.update(value="❌ Stopping...", visible=True), stop_flag
116
- def stop_batch():
117
- global_stop_flag.value = True
118
- return gr.update(value="❌ Stopping...", visible=True)
119
-
120
- # def threaded_batch_runner(file, text, email):
121
- # global_stop_flag.value = False
122
- # log_lines = []
123
-
124
- # def update_log(line):
125
- # log_lines.append(line)
126
- # yield (
127
- # gr.update(visible=False), # output_table (not yet)
128
- # gr.update(visible=False), # results_group
129
- # gr.update(visible=False), # download_file
130
- # gr.update(visible=False), # usage_display
131
- # gr.update(value=" Still processing...", visible=True), # status
132
- # gr.update(value="\n".join(log_lines)) # progress_box
133
- # )
134
-
135
- # # Start a dummy update to say "Starting..."
136
- # yield from update_log("🚀 Starting batch processing...")
137
-
138
- # rows, file_path, count, final_log, warning = mtdna_backend.summarize_batch(
139
- # file=file,
140
- # raw_text=text,
141
- # resume_file=None,
142
- # user_email=email,
143
- # stop_flag=global_stop_flag,
144
- # yield_callback=lambda line: (yield from update_log(line))
145
- # )
146
-
147
- # html = make_html_table(rows)
148
- # file_update = gr.update(value=file_path, visible=True) if os.path.exists(file_path) else gr.update(visible=False)
149
- # usage_or_warning_text = f"**{count}** samples used by this email." if email.strip() else warning
150
-
151
- # yield (
152
- # html,
153
- # gr.update(visible=True), # results_group
154
- # file_update, # download_file
155
- # gr.update(value=usage_or_warning_text, visible=True),
156
- # gr.update(value="✅ Done", visible=True),
157
- # gr.update(value=final_log)
158
- # )
159
-
160
- # def threaded_batch_runner(file=None, text="", email=""):
161
- # print("📧 EMAIL RECEIVED:", email)
162
- # import tempfile
163
- # from mtdna_backend import (
164
- # extract_accessions_from_input,
165
- # summarize_results,
166
- # save_to_excel,
167
- # hash_user_id,
168
- # increment_usage,
169
- # )
170
- # import os
171
-
172
- # global_stop_flag.value = False # reset stop flag
173
-
174
- # tmp_dir = tempfile.mkdtemp()
175
- # output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
176
- # limited_acc = 50 + (10 if email.strip() else 0)
177
-
178
- # # Step 1: Parse input
179
- # accessions, error = extract_accessions_from_input(file, text)
180
- # print(accessions)
181
- # if error:
182
- # yield (
183
- # "", # output_table
184
- # gr.update(visible=False), # results_group
185
- # gr.update(visible=False), # download_file
186
- # "", # usage_display
187
- # "❌ Error", # status
188
- # str(error) # progress_box
189
- # )
190
- # return
191
-
192
- # total = len(accessions)
193
- # if total > limited_acc:
194
- # accessions = accessions[:limited_acc]
195
- # warning = f"⚠️ Only processing first {limited_acc} accessions."
196
- # else:
197
- # warning = f"✅ All {total} accessions will be processed."
198
-
199
- # all_rows = []
200
- # processed_accessions = 0 # ✅ tracks how many accessions were processed
201
- # email_tracked = False
202
- # log_lines = []
203
-
204
- # # Step 2: Loop through accessions
205
- # for i, acc in enumerate(accessions):
206
- # if global_stop_flag.value:
207
- # log_lines.append(f"🛑 Stopped at {acc} ({i+1}/{total})")
208
- # usage_text = ""
209
- # if email.strip() and not email_tracked:
210
- # # user_hash = hash_user_id(email)
211
- # # usage_count = increment_usage(user_hash, len(all_rows))
212
- # print("print(processed_accessions at stop) ",processed_accessions)
213
- # usage_count = increment_usage(email, processed_accessions)
214
- # email_tracked = True
215
- # usage_text = f"**{usage_count}** samples used by this email. Ten more samples are added first (you now have 60 limited accessions), then wait we will contact you via this email."
216
- # else:
217
- # usage_text = f"The limited accession is 50. The user has used {processed_accessions}, and only {50-processed_accessions} left."
218
- # yield (
219
- # make_html_table(all_rows),
220
- # gr.update(visible=True),
221
- # gr.update(value=output_file_path, visible=True),
222
- # gr.update(value=usage_text, visible=True),
223
- # "🛑 Stopped",
224
- # "\n".join(log_lines)
225
- # )
226
- # return
227
-
228
- # log_lines.append(f"[{i+1}/{total}] Processing {acc}")
229
- # yield (
230
- # make_html_table(all_rows),
231
- # gr.update(visible=True),
232
- # gr.update(visible=False),
233
- # "",
234
- # "⏳ Processing...",
235
- # "\n".join(log_lines)
236
- # )
237
-
238
- # try:
239
- # print(acc)
240
- # rows = summarize_results(acc)
241
- # all_rows.extend(rows)
242
- # processed_accessions += 1 # ✅ count only successful accessions
243
- # save_to_excel(all_rows, "", "", output_file_path, is_resume=False)
244
- # log_lines.append(f"✅ Processed {acc} ({i+1}/{total})")
245
- # except Exception as e:
246
- # log_lines.append(f"❌ Failed to process {acc}: {e}")
247
-
248
- # yield (
249
- # make_html_table(all_rows),
250
- # gr.update(visible=True),
251
- # gr.update(visible=False),
252
- # "",
253
- # "⏳ Processing...",
254
- # "\n".join(log_lines)
255
- # )
256
-
257
- # # Final update
258
- # usage_text = ""
259
-
260
- # if email.strip() and not email_tracked:
261
- # # user_hash = hash_user_id(email)
262
- # # usage_count = increment_usage(user_hash, len(all_rows))
263
- # print("print(processed_accessions final) ",processed_accessions)
264
- # usage_count = increment_usage(email, processed_accessions)
265
- # usage_text = f"**{usage_count}** samples used by this email. Ten more samples are added first (you now have 60 limited accessions), then wait we will contact you via this email."
266
- # elif not email.strip():
267
- # usage_text = f"The limited accession is 50. The user has used {processed_accessions}, and only {50-processed_accessions} left."
268
- # yield (
269
- # make_html_table(all_rows),
270
- # gr.update(visible=True),
271
- # gr.update(value=output_file_path, visible=True),
272
- # gr.update(value=usage_text, visible=True),
273
- # "✅ Done",
274
- # "\n".join(log_lines)
275
- # )
276
- def threaded_batch_runner(file=None, text="", email=""):
277
- print("📧 EMAIL RECEIVED:", repr(email))
278
- import tempfile
279
- from mtdna_backend import (
280
- extract_accessions_from_input,
281
- summarize_results,
282
- save_to_excel,
283
- increment_usage,
284
- )
285
- import os
286
-
287
- global_stop_flag.value = False # reset stop flag
288
-
289
- tmp_dir = tempfile.mkdtemp()
290
- output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
291
- limited_acc = 30 + (20 if email.strip() else 0)
292
- # Step 1: Parse input
293
- accessions, error = extract_accessions_from_input(file, text)
294
- print("🧪 Accessions received:", accessions)
295
- if error:
296
- yield (
297
- "", # output_table
298
- gr.update(visible=False), # results_group
299
- gr.update(visible=False), # download_file
300
- "", # usage_display
301
- "❌ Error", # status
302
- str(error) # progress_box
303
- )
304
- return
305
-
306
- total = len(accessions)
307
- if total > limited_acc:
308
- accessions = accessions[:limited_acc]
309
- warning = f"⚠️ Only processing first {limited_acc} accessions."
310
- else:
311
- warning = f" All {total} accessions will be processed."
312
-
313
- all_rows = []
314
- processed_accessions = 0 # ✅ track successful accessions
315
- email_tracked = False
316
- log_lines = []
317
- if not email.strip():
318
- output_file_path = None#"Write your email so that you can download the outputs."
319
- log_lines.append("📥 Provide your email to receive a downloadable Excel report and get 20 more free queries.")
320
- if email.strip():
321
- usage_count = increment_usage(email, processed_accessions)
322
- if int(usage_count) > 50:
323
- log_lines.append("❌ You have reached your quota. Please contact us to unlock more.")
324
-
325
- # Minimal blank yield to trigger UI rendering
326
- yield (
327
- make_html_table([]),
328
- gr.update(visible=True),
329
- gr.update(visible=False),
330
- gr.update(value="", visible=True),
331
- "⛔️ Quota limit",
332
- "⛔️ Quota limit"
333
- )
334
-
335
- # Actual warning frame
336
- yield (
337
- make_html_table([]),
338
- gr.update(visible=False),
339
- gr.update(visible=False),
340
- gr.update(value="❌ You have reached your quota. Please contact us to unlock more.", visible=True),
341
- "❌ Quota Exceeded",
342
- "\n".join(log_lines)
343
- )
344
- return
345
-
346
-
347
- # Step 2: Loop through accessions
348
- for i, acc in enumerate(accessions):
349
- if global_stop_flag.value:
350
- log_lines.append(f"🛑 Stopped at {acc} ({i+1}/{total})")
351
- usage_text = ""
352
-
353
- if email.strip() and not email_tracked:
354
- print(f"🧪 increment_usage at STOP: {email=} {processed_accessions=}")
355
- usage_count = increment_usage(email, processed_accessions)
356
- email_tracked = True
357
- usage_text = f"**{usage_count}**/50 free samples used by this email."
358
- #Ten more samples are added first (you now have 60 limited accessions), then wait we will contact you via this email."
359
- else:
360
- usage_text = f"The limited accession is 30. The user has used {processed_accessions}, and only {30 - processed_accessions} left."
361
-
362
- yield (
363
- make_html_table(all_rows),
364
- gr.update(visible=True),
365
- #gr.update(value=output_file_path, visible=True),
366
- gr.update(value=output_file_path, visible=bool(output_file_path)),
367
- gr.update(value=usage_text, visible=True),
368
- "🛑 Stopped",
369
- "\n".join(log_lines)
370
- )
371
- return
372
-
373
- log_lines.append(f"[{i+1}/{total}] Processing {acc}")
374
- yield (
375
- make_html_table(all_rows),
376
- gr.update(visible=True),
377
- gr.update(visible=False),
378
- "",
379
- " Processing...",
380
- "\n".join(log_lines)
381
- )
382
-
383
- try:
384
- print("📄 Processing accession:", acc)
385
- rows = summarize_results(acc)
386
- all_rows.extend(rows)
387
- processed_accessions += 1 # ✅ only count success
388
- if email.strip():
389
- save_to_excel(all_rows, "", "", output_file_path, is_resume=False)
390
- log_lines.append(f"✅ Processed {acc} ({i+1}/{total})")
391
- except Exception as e:
392
- log_lines.append(f"❌ Failed to process {acc}: {e}")
393
-
394
- yield (
395
- make_html_table(all_rows),
396
- gr.update(visible=True),
397
- gr.update(visible=False),
398
- "",
399
- "⏳ Processing...",
400
- "\n".join(log_lines)
401
- )
402
-
403
- # Step 3: Final usage update
404
- usage_text = ""
405
- if email.strip() and not email_tracked:
406
- print(f"🧪 increment_usage at END: {email=} {processed_accessions=}")
407
- usage_count = increment_usage(email, processed_accessions)
408
- email_tracked = True
409
- usage_text = f"**{usage_count}**/50 free samples used by this email."
410
- #Ten more samples are added first (you now have 60 limited accessions), then wait we will contact you via this email."
411
- elif not email.strip():
412
- usage_text = f"The limited accession is 30. The user has used {processed_accessions}, and only {30 - processed_accessions} left."
413
-
414
- yield (
415
- make_html_table(all_rows),
416
- gr.update(visible=True),
417
- #gr.update(value=output_file_path, visible=True),
418
- gr.update(value=output_file_path, visible=bool(output_file_path)),
419
- gr.update(value=usage_text, visible=True),
420
- "✅ Done",
421
- "\n".join(log_lines)
422
- )
423
-
424
-
425
- # def threaded_batch_runner(file=None, text="", email=""):
426
- # global_stop_flag.value = False
427
-
428
- # # Dummy test output that matches expected schema
429
- # return (
430
- # "<div>✅ Dummy output table</div>", # HTML string
431
- # gr.update(visible=True), # Group visibility
432
- # gr.update(visible=False), # Download file
433
- # "**0** samples used.", # Markdown
434
- # "✅ Done", # Status string
435
- # "Processing finished." # Progress string
436
- # )
437
-
438
-
439
- # def classify_mulAcc(file, text, resume, email, log_callback=None, log_collector=None):
440
- # stop_flag.value = False
441
- # return threaded_batch_runner(file, text, resume, email, status, stop_flag, log_callback=log_callback, log_collector=log_collector)
442
-
443
-
444
- def make_html_table(rows):
445
- # html = """
446
- # <div style='overflow-x: auto; padding: 10px;'>
447
- # <div style='max-height: 400px; overflow-y: auto; border: 1px solid #444; border-radius: 8px;'>
448
- # <table style='width:100%; border-collapse: collapse; table-layout: auto; font-size: 14px; color: #f1f1f1; background-color: #1e1e1e;'>
449
- # <thead style='position: sticky; top: 0; background-color: #2c2c2c; z-index: 1;'>
450
- # <tr>
451
- # """
452
- html = """
453
- <div style='overflow-x: auto; padding: 10px;'>
454
- <div style='max-height: 400px; overflow-y: auto; border: 1px solid #ccc; border-radius: 8px;'>
455
- <table style='width:100%; border-collapse: collapse; table-layout: auto; font-size: 14px; color: inherit; background-color: inherit;'>
456
- """
457
-
458
- headers = ["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"]
459
- html += "".join(
460
- f"<th style='padding: 10px; border: 1px solid #555; text-align: left; white-space: nowrap;'>{h}</th>"
461
- for h in headers
462
- )
463
- html += "</tr></thead><tbody>"
464
-
465
- for row in rows:
466
- html += "<tr>"
467
- for i, col in enumerate(row):
468
- header = headers[i]
469
- style = "padding: 10px; border: 1px solid #555; vertical-align: top;"
470
-
471
- # For specific columns like Haplogroup, force nowrap
472
- if header in ["Country Explanation", "Sample Type Explanation"]:
473
- style += " max-width: 400px; word-wrap: break-word; white-space: normal;"
474
- elif header in ["Sample ID", "Predicted Country", "Predicted Sample Type", "Time cost"]:
475
- style += " white-space: nowrap; text-overflow: ellipsis; max-width: 200px; overflow: hidden;"
476
-
477
- # if header == "Sources" and isinstance(col, str) and col.strip().lower().startswith("http"):
478
- # col = f"<a href='{col}' target='_blank' style='color: #4ea1f3; text-decoration: underline;'>{col}</a>"
479
-
480
- #html += f"<td style='{style}'>{col}</td>"
481
- if header == "Sources" and isinstance(col, str):
482
- links = [f"<a href='{url.strip()}' target='_blank' style='color: #4ea1f3; text-decoration: underline;'>{url.strip()}</a>" for url in col.strip().split("\n") if url.strip()]
483
- col = "- "+"<br>- ".join(links)
484
- elif isinstance(col, str):
485
- # lines = []
486
- # for line in col.split("\n"):
487
- # line = line.strip()
488
- # if not line:
489
- # continue
490
- # if line.lower().startswith("rag_llm-"):
491
- # content = line[len("rag_llm-"):].strip()
492
- # line = f"{content} (Method: RAG_LLM)"
493
- # lines.append(f"- {line}")
494
- col = col.replace("\n", "<br>")
495
- #col = col.replace("\t", "&nbsp;&nbsp;&nbsp;&nbsp;")
496
- #col = "<br>".join(lines)
497
-
498
- html += f"<td style='{style}'>{col}</td>"
499
- html += "</tr>"
500
-
501
- html += "</tbody></table></div></div>"
502
- return html
503
-
504
-
505
- # def reset_fields():
506
- # global_stop_flag.value = False # 💡 Add this to reset the flag
507
- # return (
508
- # #gr.update(value=""), # single_accession
509
- # gr.update(value=""), # raw_text
510
- # gr.update(value=None), # file_upload
511
- # #gr.update(value=None), # resume_file
512
- # #gr.update(value="Single Accession"), # inputMode
513
- # gr.update(value=[], visible=True), # output_table
514
- # # gr.update(value="", visible=True), # output_summary
515
- # # gr.update(value="", visible=True), # output_flag
516
- # gr.update(visible=False), # status
517
- # gr.update(visible=False), # results_group
518
- # gr.update(value="", visible=False), # usage_display
519
- # gr.update(value="", visible=False), # progress_box
520
- # )
521
- def reset_fields():
522
- global_stop_flag.value = False # Reset the stop flag
523
-
524
- return (
525
- gr.update(value=""), # raw_text
526
- gr.update(value=None), # file_upload
527
- gr.update(value=[], visible=True), # output_table
528
- gr.update(value="", visible=True), # status — reset and make visible again
529
- gr.update(visible=False), # results_group
530
- gr.update(value="", visible=True), # usage_display — reset and make visible again
531
- gr.update(value="", visible=True), # progress_box reset AND visible!
532
- )
533
- #inputMode.change(fn=toggle_input_mode, inputs=inputMode, outputs=[single_input_group, batch_input_group])
534
- #run_button.click(fn=classify_with_loading, inputs=[], outputs=[status])
535
- # run_button.click(
536
- # fn=classify_dynamic,
537
- # inputs=[single_accession, file_upload, raw_text, resume_file,user_email,inputMode],
538
- # outputs=[output_table,
539
- # #output_summary, output_flag,
540
- # results_group, download_file, usage_display,status, progress_box]
541
- # )
542
-
543
- # run_button.click(
544
- # fn=threaded_batch_runner,
545
- # #inputs=[file_upload, raw_text, resume_file, user_email],
546
- # inputs=[file_upload, raw_text, user_email],
547
- # outputs=[output_table, results_group, download_file, usage_display, status, progress_box]
548
- # )
549
- # run_button.click(
550
- # fn=threaded_batch_runner,
551
- # inputs=[file_upload, raw_text, user_email],
552
- # outputs=[output_table, results_group, download_file, usage_display, status, progress_box],
553
- # every=0.5 # <-- this tells Gradio to expect streaming
554
- # )
555
- # output_table = gr.HTML()
556
- # results_group = gr.Group(visible=False)
557
- # download_file = gr.File(visible=False)
558
- # usage_display = gr.Markdown(visible=False)
559
- # status = gr.Markdown(visible=False)
560
- # progress_box = gr.Textbox(visible=False)
561
-
562
- # run_button.click(
563
- # fn=threaded_batch_runner,
564
- # inputs=[file_upload, raw_text, user_email],
565
- # outputs=[output_table, results_group, download_file, usage_display, status, progress_box],
566
- # every=0.5, # streaming enabled
567
- # show_progress="full"
568
- # )
569
-
570
- # interface.stream(
571
- # fn=threaded_batch_runner,
572
- # inputs=[file_upload, raw_text, user_email],
573
- # outputs=[output_table, results_group, download_file, usage_display, status, progress_box],
574
- # trigger=run_button,
575
- # every=0.5,
576
- # show_progress="full",
577
- # )
578
- interface.queue() # No arguments here!
579
-
580
- run_button.click(
581
- fn=threaded_batch_runner,
582
- inputs=[file_upload, raw_text, user_email],
583
- outputs=[output_table, results_group, download_file, usage_display, status, progress_box],
584
- concurrency_limit=1, # correct in Gradio 5.x
585
- queue=True, # ensure the queue is used
586
- #every=0.5
587
- )
588
-
589
-
590
-
591
-
592
- stop_button.click(fn=stop_batch, inputs=[], outputs=[status])
593
-
594
- # reset_button.click(
595
- # #fn=reset_fields,
596
- # fn=lambda: (
597
- # gr.update(value=""), gr.update(value=""), gr.update(value=None), gr.update(value=None), gr.update(value="Single Accession"),
598
- # gr.update(value=[], visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value="", visible=False), gr.update(value="", visible=False)
599
- # ),
600
- # inputs=[],
601
- # outputs=[
602
- # single_accession, raw_text, file_upload, resume_file,inputMode,
603
- # output_table,# output_summary, output_flag,
604
- # status, results_group, usage_display, progress_box
605
- # ]
606
- # )
607
- #stop_button.click(fn=lambda sf: (gr.update(value="❌ Stopping...", visible=True), setattr(sf, "value", True) or sf), inputs=[gr.State(stop_flag)], outputs=[status, gr.State(stop_flag)])
608
-
609
- reset_button.click(
610
- fn=reset_fields,
611
- inputs=[],
612
- #outputs=[raw_text, file_upload, resume_file, output_table, status, results_group, usage_display, progress_box]
613
- outputs=[raw_text, file_upload, output_table, status, results_group, usage_display, progress_box]
614
- )
615
-
616
- # download_button.click(
617
- # fn=mtdna_backend.save_batch_output,
618
- # #inputs=[output_table, output_summary, output_flag, output_type],
619
- # inputs=[output_table, output_type],
620
- # outputs=[download_file])
621
-
622
- # submit_feedback.click(
623
- # fn=mtdna_backend.store_feedback_to_google_sheets,
624
- # inputs=[single_accession, q1, q2, contact], outputs=feedback_status
625
- # )
626
- submit_feedback.click(
627
- fn=mtdna_backend.store_feedback_to_google_sheets,
628
- inputs=[raw_text, q1, q2, contact],
629
- outputs=[feedback_status]
630
- )
631
- gr.HTML("""
632
- <style>
633
- body, html {
634
- background-color: #121212 !important;
635
- color: #ffffff !important;
636
- }
637
-
638
- .gradio-container, .gr-block, .gr-box, textarea, input, select, .prose, .prose * {
639
- background-color: #1e1e1e !important;
640
- color: #ffffff !important;
641
- border-color: #333 !important;
642
- }
643
-
644
- textarea::placeholder,
645
- input::placeholder {
646
- color: #aaa !important;
647
- }
648
-
649
- button {
650
- background-color: #2d2d2d !important;
651
- color: #fff !important;
652
- border: 1px solid #444 !important;
653
- }
654
-
655
- a {
656
- color: #4ea1f3 !important;
657
- }
658
- </style>
659
- """)
660
-
661
- # # Custom CSS styles
662
- # gr.HTML("""
663
- # <style>
664
- # /* Ensures both sections are equally spaced with the same background size */
665
- # #output-summary, #output-flag {
666
- # background-color: #f0f4f8; /* Light Grey for both */
667
- # padding: 20px;
668
- # border-radius: 10px;
669
- # margin-top: 10px;
670
- # width: 100%; /* Ensure full width */
671
- # min-height: 150px; /* Ensures both have a minimum height */
672
- # box-sizing: border-box; /* Prevents padding from increasing size */
673
- # display: flex;
674
- # flex-direction: column;
675
- # justify-content: space-between;
676
- # }
677
-
678
- # /* Specific background colors */
679
- # #output-summary {
680
- # background-color: #434a4b;
681
- # }
682
-
683
- # #output-flag {
684
- # background-color: #141616;
685
- # }
686
-
687
- # /* Ensuring they are in a row and evenly spaced */
688
- # .gradio-row {
689
- # display: flex;
690
- # justify-content: space-between;
691
- # width: 100%;
692
- # }
693
- # </style>
694
- # """)
695
-
696
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
697
  interface.launch(share=True,debug=True)
 
1
+ import gradio as gr
2
+ import mtdna_backend
3
+ import json
4
+ import data_preprocess, model, pipeline
5
+ import os
6
+ import hashlib
7
+ import threading
8
+ # Gradio UI
9
+ #stop_flag = gr.State(value=False)
10
+ class StopFlag:
11
+ def __init__(self):
12
+ self.value = False
13
+ global_stop_flag = StopFlag() # Shared between run + stop
14
+
15
+ with open("offer.html", "r", encoding="utf-8") as f:
16
+ pricing_html = f.read()
17
+
18
+ with gr.Blocks() as interface:
19
+ with gr.Tab("CURIOUS ABOUT THIS PRODUCT?"):
20
+ gr.HTML(value=pricing_html)
21
+ with gr.Tab("🧬 Classifier"):
22
+ gr.Markdown("# 🧬 mtDNA Location Classifier (MVP)")
23
+ #inputMode = gr.Radio(choices=["Single Accession", "Batch Input"], value="Single Accession", label="Choose Input Mode")
24
+ user_email = gr.Textbox(label="📧 Your email (used to track free quota)")
25
+ usage_display = gr.Markdown("", visible=False)
26
+
27
+ # with gr.Group() as single_input_group:
28
+ # single_accession = gr.Textbox(label="Enter Single Accession (e.g., KU131308)")
29
+
30
+ # with gr.Group(visible=False) as batch_input_group:
31
+ # raw_text = gr.Textbox(label="🧬 Paste Accession Numbers (e.g., MF362736.1,MF362738.1,KU131308,MW291678)")
32
+ # resume_file = gr.File(label="🗃️ Previously saved Excel output (optional)", file_types=[".xlsx"], interactive=True)
33
+ # gr.HTML("""<a href="https://drive.google.com/file/d/1t-TFeIsGVu5Jh3CUZS-VE9jQWzNFCs_c/view?usp=sharing" download target="_blank">Download Example CSV Format</a>""")
34
+ # gr.HTML("""<a href="https://docs.google.com/spreadsheets/d/1lKqPp17EfHsshJGZRWEpcNOZlGo3F5qU/edit?usp=sharing&ouid=112390323314156876153&rtpof=true&sd=true" download target="_blank">Download Example Excel Format</a>""")
35
+ # file_upload = gr.File(label="📁 Or Upload CSV/Excel File", file_types=[".csv", ".xlsx"], interactive=True, elem_id="file-upload-box")
36
+ raw_text = gr.Textbox(label="🧚 Input Accession Number(s) (single (KU131308) or comma-separated (e.g., MF362736.1,MF362738.1,KU131308,MW291678))")
37
+ #resume_file = gr.File(label="🗃️ Previously saved Excel output (optional)", file_types=[".xlsx"], interactive=True)
38
+ gr.HTML("""<a href="https://docs.google.com/spreadsheets/d/1lKqPp17EfHsshJGZRWEpcNOZlGo3F5qU/edit?usp=sharing" download target="_blank">Download Example Excel Format</a>""")
39
+ file_upload = gr.File(label="📁 Or Upload CSV/Excel File", file_types=[".csv", ".xlsx"], interactive=True)
40
+
41
+ with gr.Row():
42
+ run_button = gr.Button("🔍 Submit and Classify")
43
+ stop_button = gr.Button("❌ Stop Batch", visible=True)
44
+ reset_button = gr.Button("🔄 Reset")
45
+
46
+ status = gr.Markdown(visible=False)
47
+
48
+ with gr.Group(visible=False) as results_group:
49
+ # with gr.Accordion("Open to See the Result", open=False) as results:
50
+ # with gr.Row():
51
+ # output_summary = gr.Markdown(elem_id="output-summary")
52
+ # output_flag = gr.Markdown(elem_id="output-flag")
53
+
54
+ # gr.Markdown("---")
55
+
56
+ with gr.Accordion("Open to See the Output Table", open=False) as table_accordion:
57
+ output_table = gr.HTML(render=True)
58
+ #with gr.Row():
59
+ #output_type = gr.Dropdown(choices=["Excel", "JSON", "TXT"], label="Select Output Format", value="Excel")
60
+ #download_button = gr.Button("⬇️ Download Output")
61
+ #download_file = gr.File(label="Download File Here",visible=False)
62
+ report_button = gr.Button("Report")
63
+ report_textbox = gr.Textbox(
64
+ label="Describe the issue",
65
+ lines=4,
66
+ placeholder="e.g. DQ981467: it gives me unknown when I can in fact search it on NCBI \n DQ981467: same reason as above",
67
+ visible=False)
68
+ submit_report_button = gr.Button("Submit", visible=False)
69
+ status_report = gr.Markdown(visible=False)
70
+
71
+ download_file = gr.File(label="Download File Here", visible=False, interactive=True)
72
+ progress_box = gr.Textbox(label="Live Processing Log", lines=20, interactive=False)
73
+
74
+ gr.Markdown("---")
75
+
76
+ gr.Markdown("### 💬 Feedback (required)")
77
+ q1 = gr.Textbox(label="1️⃣ Was the inferred location accurate or helpful? Please explain.")
78
+ q2 = gr.Textbox(label="2️⃣ What would improve your experience with this tool?")
79
+ contact = gr.Textbox(label="📧 Your email or institution (optional)")
80
+ submit_feedback = gr.Button("✅ Submit Feedback")
81
+ feedback_status = gr.Markdown()
82
+
83
+ # Functions
84
+ # def toggle_input_mode(mode):
85
+ # if mode == "Single Accession":
86
+ # return gr.update(visible=True), gr.update(visible=False)
87
+ # else:
88
+ # return gr.update(visible=False), gr.update(visible=True)
89
+
90
+ def classify_with_loading():
91
+ return gr.update(value="⏳ Please wait... processing...",visible=True) # Show processing message
92
+
93
+ # def classify_dynamic(single_accession, file, text, resume, email, mode):
94
+ # if mode == "Single Accession":
95
+ # return classify_main(single_accession) + (gr.update(visible=False),)
96
+ # else:
97
+ # #return summarize_batch(file, text) + (gr.update(visible=False),) # Hide processing message
98
+ # return classify_mulAcc(file, text, resume) + (gr.update(visible=False),) # Hide processing message
99
+ # Logging helpers defined early to avoid NameError
100
+
101
+
102
+ # def classify_dynamic(single_accession, file, text, resume, email, mode):
103
+ # if mode == "Single Accession":
104
+ # return classify_main(single_accession) + (gr.update(value="", visible=False),)
105
+ # else:
106
+ # return classify_mulAcc(file, text, resume, email, log_callback=real_time_logger, log_collector=log_collector)
107
+
108
+ # for single accession
109
+ # def classify_main(accession):
110
+ # #table, summary, labelAncient_Modern, explain_label = mtdna_backend.summarize_results(accession)
111
+ # table = mtdna_backend.summarize_results(accession)
112
+ # #flag_output = f"### 🏺 Ancient/Modern Flag\n**{labelAncient_Modern}**\n\n_Explanation:_ {explain_label}"
113
+ # return (
114
+ # #table,
115
+ # make_html_table(table),
116
+ # # summary,
117
+ # # flag_output,
118
+ # gr.update(visible=True),
119
+ # gr.update(visible=False),
120
+ # gr.update(visible=False)
121
+ # )
122
+
123
+ #stop_flag = gr.State(value=False)
124
+ #stop_flag = StopFlag()
125
+
126
+ # def stop_batch(stop_flag):
127
+ # stop_flag.value = True
128
+ # return gr.update(value="❌ Stopping...", visible=True), stop_flag
129
+ def stop_batch():
130
+ global_stop_flag.value = True
131
+ return gr.update(value=" Stopping...", visible=True)
132
+
133
+ # def threaded_batch_runner(file, text, email):
134
+ # global_stop_flag.value = False
135
+ # log_lines = []
136
+
137
+ # def update_log(line):
138
+ # log_lines.append(line)
139
+ # yield (
140
+ # gr.update(visible=False), # output_table (not yet)
141
+ # gr.update(visible=False), # results_group
142
+ # gr.update(visible=False), # download_file
143
+ # gr.update(visible=False), # usage_display
144
+ # gr.update(value="⏳ Still processing...", visible=True), # status
145
+ # gr.update(value="\n".join(log_lines)) # progress_box
146
+ # )
147
+
148
+ # # Start a dummy update to say "Starting..."
149
+ # yield from update_log("🚀 Starting batch processing...")
150
+
151
+ # rows, file_path, count, final_log, warning = mtdna_backend.summarize_batch(
152
+ # file=file,
153
+ # raw_text=text,
154
+ # resume_file=None,
155
+ # user_email=email,
156
+ # stop_flag=global_stop_flag,
157
+ # yield_callback=lambda line: (yield from update_log(line))
158
+ # )
159
+
160
+ # html = make_html_table(rows)
161
+ # file_update = gr.update(value=file_path, visible=True) if os.path.exists(file_path) else gr.update(visible=False)
162
+ # usage_or_warning_text = f"**{count}** samples used by this email." if email.strip() else warning
163
+
164
+ # yield (
165
+ # html,
166
+ # gr.update(visible=True), # results_group
167
+ # file_update, # download_file
168
+ # gr.update(value=usage_or_warning_text, visible=True),
169
+ # gr.update(value="✅ Done", visible=True),
170
+ # gr.update(value=final_log)
171
+ # )
172
+
173
+ # def threaded_batch_runner(file=None, text="", email=""):
174
+ # print("📧 EMAIL RECEIVED:", email)
175
+ # import tempfile
176
+ # from mtdna_backend import (
177
+ # extract_accessions_from_input,
178
+ # summarize_results,
179
+ # save_to_excel,
180
+ # hash_user_id,
181
+ # increment_usage,
182
+ # )
183
+ # import os
184
+
185
+ # global_stop_flag.value = False # reset stop flag
186
+
187
+ # tmp_dir = tempfile.mkdtemp()
188
+ # output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
189
+ # limited_acc = 50 + (10 if email.strip() else 0)
190
+
191
+ # # Step 1: Parse input
192
+ # accessions, error = extract_accessions_from_input(file, text)
193
+ # print(accessions)
194
+ # if error:
195
+ # yield (
196
+ # "", # output_table
197
+ # gr.update(visible=False), # results_group
198
+ # gr.update(visible=False), # download_file
199
+ # "", # usage_display
200
+ # "❌ Error", # status
201
+ # str(error) # progress_box
202
+ # )
203
+ # return
204
+
205
+ # total = len(accessions)
206
+ # if total > limited_acc:
207
+ # accessions = accessions[:limited_acc]
208
+ # warning = f"⚠️ Only processing first {limited_acc} accessions."
209
+ # else:
210
+ # warning = f"✅ All {total} accessions will be processed."
211
+
212
+ # all_rows = []
213
+ # processed_accessions = 0 # ✅ tracks how many accessions were processed
214
+ # email_tracked = False
215
+ # log_lines = []
216
+
217
+ # # Step 2: Loop through accessions
218
+ # for i, acc in enumerate(accessions):
219
+ # if global_stop_flag.value:
220
+ # log_lines.append(f"🛑 Stopped at {acc} ({i+1}/{total})")
221
+ # usage_text = ""
222
+ # if email.strip() and not email_tracked:
223
+ # # user_hash = hash_user_id(email)
224
+ # # usage_count = increment_usage(user_hash, len(all_rows))
225
+ # print("print(processed_accessions at stop) ",processed_accessions)
226
+ # usage_count = increment_usage(email, processed_accessions)
227
+ # email_tracked = True
228
+ # usage_text = f"**{usage_count}** samples used by this email. Ten more samples are added first (you now have 60 limited accessions), then wait we will contact you via this email."
229
+ # else:
230
+ # usage_text = f"The limited accession is 50. The user has used {processed_accessions}, and only {50-processed_accessions} left."
231
+ # yield (
232
+ # make_html_table(all_rows),
233
+ # gr.update(visible=True),
234
+ # gr.update(value=output_file_path, visible=True),
235
+ # gr.update(value=usage_text, visible=True),
236
+ # "🛑 Stopped",
237
+ # "\n".join(log_lines)
238
+ # )
239
+ # return
240
+
241
+ # log_lines.append(f"[{i+1}/{total}] Processing {acc}")
242
+ # yield (
243
+ # make_html_table(all_rows),
244
+ # gr.update(visible=True),
245
+ # gr.update(visible=False),
246
+ # "",
247
+ # "⏳ Processing...",
248
+ # "\n".join(log_lines)
249
+ # )
250
+
251
+ # try:
252
+ # print(acc)
253
+ # rows = summarize_results(acc)
254
+ # all_rows.extend(rows)
255
+ # processed_accessions += 1 # ✅ count only successful accessions
256
+ # save_to_excel(all_rows, "", "", output_file_path, is_resume=False)
257
+ # log_lines.append(f"✅ Processed {acc} ({i+1}/{total})")
258
+ # except Exception as e:
259
+ # log_lines.append(f"❌ Failed to process {acc}: {e}")
260
+
261
+ # yield (
262
+ # make_html_table(all_rows),
263
+ # gr.update(visible=True),
264
+ # gr.update(visible=False),
265
+ # "",
266
+ # "⏳ Processing...",
267
+ # "\n".join(log_lines)
268
+ # )
269
+
270
+ # # Final update
271
+ # usage_text = ""
272
+
273
+ # if email.strip() and not email_tracked:
274
+ # # user_hash = hash_user_id(email)
275
+ # # usage_count = increment_usage(user_hash, len(all_rows))
276
+ # print("print(processed_accessions final) ",processed_accessions)
277
+ # usage_count = increment_usage(email, processed_accessions)
278
+ # usage_text = f"**{usage_count}** samples used by this email. Ten more samples are added first (you now have 60 limited accessions), then wait we will contact you via this email."
279
+ # elif not email.strip():
280
+ # usage_text = f"The limited accession is 50. The user has used {processed_accessions}, and only {50-processed_accessions} left."
281
+ # yield (
282
+ # make_html_table(all_rows),
283
+ # gr.update(visible=True),
284
+ # gr.update(value=output_file_path, visible=True),
285
+ # gr.update(value=usage_text, visible=True),
286
+ # "✅ Done",
287
+ # "\n".join(log_lines)
288
+ # )
289
+
290
+ def threaded_batch_runner(file=None, text="", email=""):
291
+ print("📧 EMAIL RECEIVED:", repr(email))
292
+ import tempfile
293
+ from mtdna_backend import (
294
+ extract_accessions_from_input,
295
+ summarize_results,
296
+ save_to_excel,
297
+ increment_usage,
298
+ )
299
+ import os
300
+
301
+ global_stop_flag.value = False # reset stop flag
302
+
303
+ tmp_dir = tempfile.mkdtemp()
304
+ output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
305
+ #output_file_path = "/mnt/data/batch_output_live.xlsx"
306
+ all_rows = []
307
+ processed_accessions = 0 # ✅ track successful accessions
308
+ email_tracked = False
309
+ log_lines = []
310
+ if not email.strip():
311
+ output_file_path = None#"Write your email so that you can download the outputs."
312
+ log_lines.append("📥 Provide your email to receive a downloadable Excel report and get 20 more free queries.")
313
+ limited_acc = 30
314
+ if email.strip():
315
+ usage_count, max_allowed = increment_usage(email, processed_accessions)
316
+ if int(usage_count) >= int(max_allowed):
317
+ log_lines.append("❌ You have reached your quota. Please contact us to unlock more.")
318
+
319
+ # Minimal blank yield to trigger UI rendering
320
+ yield (
321
+ make_html_table([]),
322
+ gr.update(visible=True),
323
+ gr.update(visible=False),
324
+ gr.update(value="", visible=True),
325
+ "⛔️ Quota limit",
326
+ "⛔️ Quota limit"
327
+ )
328
+
329
+ # Actual warning frame
330
+ yield (
331
+ make_html_table([]),
332
+ gr.update(visible=False),
333
+ gr.update(visible=False),
334
+ gr.update(value="❌ You have reached your quota. Please contact us to unlock more.", visible=True),
335
+ "❌ Quota Exceeded",
336
+ "\n".join(log_lines)
337
+ )
338
+ return
339
+ limited_acc = int(max_allowed-usage_count)
340
+ # Step 1: Parse input
341
+ accessions, error = extract_accessions_from_input(file, text)
342
+ print("🧪 Accessions received:", accessions)
343
+ if error:
344
+ yield (
345
+ "", # output_table
346
+ gr.update(visible=False), # results_group
347
+ gr.update(visible=False), # download_file
348
+ "", # usage_display
349
+ "❌ Error", # status
350
+ str(error) # progress_box
351
+ )
352
+ return
353
+
354
+ total = len(accessions)
355
+ if total > limited_acc:
356
+ accessions = accessions[:limited_acc]
357
+ warning = f"⚠️ Only processing the first {limited_acc} accessions."
358
+ else:
359
+ warning = f"✅ All {total} accessions will be processed."
360
+ log_lines.append(warning)  # surface the notice in the progress log ('warning' was otherwise never shown)
+
361
+ # all_rows = []
362
+ # processed_accessions = 0 # ✅ track successful accessions
363
+ # email_tracked = False
364
+ # log_lines = []
365
+ # if not email.strip():
366
+ # output_file_path = None#"Write your email so that you can download the outputs."
367
+ # log_lines.append("📥 Provide your email to receive a downloadable Excel report and get 20 more free queries.")
368
+ # if email.strip():
369
+ # usage_count, max_allowed = increment_usage(email, processed_accessions)
370
+ # if int(usage_count) > int(max_allowed):
371
+ # log_lines.append("❌ You have reached your quota. Please contact us to unlock more.")
372
+
373
+ # # Minimal blank yield to trigger UI rendering
374
+ # yield (
375
+ # make_html_table([]),
376
+ # gr.update(visible=True),
377
+ # gr.update(visible=False),
378
+ # gr.update(value="", visible=True),
379
+ # "⛔️ Quota limit",
380
+ # "⛔️ Quota limit"
381
+ # )
382
+
383
+ # # Actual warning frame
384
+ # yield (
385
+ # make_html_table([]),
386
+ # gr.update(visible=False),
387
+ # gr.update(visible=False),
388
+ # gr.update(value="❌ You have reached your quota. Please contact us to unlock more.", visible=True),
389
+ # " Quota Exceeded",
390
+ # "\n".join(log_lines)
391
+ # )
392
+ # return
393
+
394
+
395
+ # Step 2: Loop through accessions
396
+ for i, acc in enumerate(accessions):
397
+ if global_stop_flag.value:
398
+ log_lines.append(f"🛑 Stopped at {acc} ({i+1}/{total})")
399
+ usage_text = ""
400
+
401
+ if email.strip() and not email_tracked:
402
+ print(f"🧪 increment_usage at STOP: {email=} {processed_accessions=}")
403
+ usage_count, max_allowed = increment_usage(email, processed_accessions)
404
+ email_tracked = True
405
+ usage_text = f"**{usage_count}** of {max_allowed} allowed samples have been used by this email."
406
+ #Ten more samples are added first (you now have 60 limited accessions), then wait we will contact you via this email."
407
+ else:
408
+ usage_text = f"The accession limit is 30. You have used {processed_accessions}, and only {30 - processed_accessions} remain."
409
+
410
+ yield (
411
+ make_html_table(all_rows),
412
+ gr.update(visible=True),
413
+ #gr.update(value=output_file_path, visible=True),
414
+ gr.update(value=output_file_path, visible=bool(output_file_path)),
415
+ gr.update(value=usage_text, visible=True),
416
+ "🛑 Stopped",
417
+ "\n".join(log_lines)
418
+ )
419
+ return
420
+
421
+ log_lines.append(f"[{i+1}/{total}] Processing {acc}")
422
+ yield (
423
+ make_html_table(all_rows),
424
+ gr.update(visible=True),
425
+ gr.update(visible=False),
426
+ "",
427
+ "⏳ Processing...",
428
+ "\n".join(log_lines)
429
+ )
430
+
431
+ try:
432
+ print("📄 Processing accession:", acc)
433
+ rows = summarize_results(acc)
434
+ all_rows.extend(rows)
435
+ processed_accessions += 1 # only count success
436
+ if email.strip():
437
+ save_to_excel(all_rows, "", "", output_file_path, is_resume=False)
438
+ log_lines.append(f"✅ Processed {acc} ({i+1}/{total})")
439
+ except Exception as e:
440
+ log_lines.append(f"❌ Failed to process {acc}: {e}")
441
+
442
+ yield (
443
+ make_html_table(all_rows),
444
+ gr.update(visible=True),
445
+ gr.update(visible=False),
446
+ "",
447
+ "⏳ Processing...",
448
+ "\n".join(log_lines)
449
+ )
450
+
451
+ # Step 3: Final usage update
452
+ usage_text = ""
453
+ if email.strip() and not email_tracked:
454
+ print(f"🧪 increment_usage at END: {email=} {processed_accessions=}")
455
+ usage_count, max_allowed = increment_usage(email, processed_accessions)
456
+ email_tracked = True
457
+ usage_text = f"**{usage_count}** of {max_allowed} allowed samples have been used by this email."
458
+ #Ten more samples are added first (you now have 60 limited accessions), then wait we will contact you via this email."
459
+ elif not email.strip():
460
+ usage_text = f"The accession limit is 30. You have used {processed_accessions}, and only {30 - processed_accessions} remain."
461
+
462
+ yield (
463
+ make_html_table(all_rows),
464
+ gr.update(visible=True),
465
+ #gr.update(value=output_file_path, visible=True),
466
+ gr.update(value=output_file_path, visible=bool(output_file_path)),
467
+ gr.update(value=usage_text, visible=True),
468
+ "✅ Done",
469
+ "\n".join(log_lines)
470
+ )
471
+
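For readers unfamiliar with the streaming pattern above: threaded_batch_runner is a generator, so every yield pushes one complete frame of UI updates (table HTML, visibility toggles, download path, usage text, status, log) to the bound outputs. A minimal, self-contained sketch of the same idea, with invented component and function names, might look like this:

import time
import gradio as gr

def demo_batch(ids_text):
    # Generator handler: each yield updates (progress log, status) in place.
    ids = [x.strip() for x in ids_text.split(",") if x.strip()]
    log = []
    for i, acc in enumerate(ids, start=1):
        log.append(f"[{i}/{len(ids)}] Processing {acc}")
        yield "\n".join(log), "⏳ Processing..."
        time.sleep(0.2)  # stand-in for the real per-accession work
    yield "\n".join(log), "✅ Done"

with gr.Blocks() as sketch:
    ids_box = gr.Textbox(label="Accessions (comma-separated)")
    run = gr.Button("Run")
    progress = gr.Textbox(label="Progress", lines=8)
    status = gr.Markdown()
    run.click(fn=demo_batch, inputs=ids_box, outputs=[progress, status])

# sketch.queue(); sketch.launch()  # generator handlers only stream when the queue is enabled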
472
+ # SUBMIT REPORT UI
473
+ # 1. Google Sheets setup
474
+ def get_worksheet(sheet_name="Report"):
475
+ import os, json
476
+ import gspread
477
+ from oauth2client.service_account import ServiceAccountCredentials
478
+ try:
479
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
480
+ scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
481
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
482
+ client = gspread.authorize(creds)
483
+ sheet = client.open(sheet_name).sheet1
484
+ return sheet
485
+ except Exception as e:
486
+ print(f"❌ Error loading Google Sheet '{sheet_name}':", e)
487
+ return None
488
+
489
+ # 2. Submit function to send report to the Google Sheet
490
+ def submit_report(report_text,user_email=""):
491
+ try:
492
+ sheet = get_worksheet()
493
+ # ✅ Parse the report_text (each line like 'ACCESSION: message')
494
+ lines = report_text.strip().split('\n')
495
+ user = ""
496
+ if user_email.strip():
497
+ user = user_email
498
+ for line in lines:
499
+ if ':' in line:
500
+ accession, message = line.split(':', 1)
501
+ sheet.append_row([accession.strip(), message.strip(), user.strip()])
502
+ return "✅ Report submitted successfully!"
503
+ except Exception as e:
504
+ return f"❌ Error submitting report: {str(e)}"
505
+ def show_report_ui():
506
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
507
+
508
+ def handle_submission(text,user_email):
509
+ msg = submit_report(text, user_email)
510
+ return gr.update(value=msg, visible=True), gr.update(visible=False), gr.update(visible=False)
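As a hypothetical usage sketch (file name, accessions and email invented): the report path only needs the GCP_CREDS_JSON environment variable to hold the full service-account key JSON; submit_report then appends one sheet row per "ACCESSION: message" line.

import os
from pathlib import Path

# Assumption: service_account.json is a Google service-account key with access to the "Report" sheet.
os.environ["GCP_CREDS_JSON"] = Path("service_account.json").read_text()

print(submit_report(
    "AB123456: predicted country looks wrong\nCD789012: sample type should be ancient",
    user_email="user@example.org",
))  # -> "✅ Report submitted successfully!" on success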
511
+ # def threaded_batch_runner(file=None, text="", email=""):
512
+ # global_stop_flag.value = False
513
+
514
+ # # Dummy test output that matches expected schema
515
+ # return (
516
+ # "<div>✅ Dummy output table</div>", # HTML string
517
+ # gr.update(visible=True), # Group visibility
518
+ # gr.update(visible=False), # Download file
519
+ # "**0** samples used.", # Markdown
520
+ # "✅ Done", # Status string
521
+ # "Processing finished." # Progress string
522
+ # )
523
+
524
+
525
+ # def classify_mulAcc(file, text, resume, email, log_callback=None, log_collector=None):
526
+ # stop_flag.value = False
527
+ # return threaded_batch_runner(file, text, resume, email, status, stop_flag, log_callback=log_callback, log_collector=log_collector)
528
+
529
+
530
+ def make_html_table(rows):
531
+ # html = """
532
+ # <div style='overflow-x: auto; padding: 10px;'>
533
+ # <div style='max-height: 400px; overflow-y: auto; border: 1px solid #444; border-radius: 8px;'>
534
+ # <table style='width:100%; border-collapse: collapse; table-layout: auto; font-size: 14px; color: #f1f1f1; background-color: #1e1e1e;'>
535
+ # <thead style='position: sticky; top: 0; background-color: #2c2c2c; z-index: 1;'>
536
+ # <tr>
537
+ # """
538
+ html = """
539
+ <div style='overflow-x: auto; padding: 10px;'>
540
+ <div style='max-height: 400px; overflow-y: auto; border: 1px solid #ccc; border-radius: 8px;'>
541
+ <table style='width:100%; border-collapse: collapse; table-layout: auto; font-size: 14px; color: inherit; background-color: inherit;'>
+ <thead style='position: sticky; top: 0; z-index: 1;'><tr>
542
+ """
543
+
544
+ headers = ["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"]
545
+ html += "".join(
546
+ f"<th style='padding: 10px; border: 1px solid #555; text-align: left; white-space: nowrap;'>{h}</th>"
547
+ for h in headers
548
+ )
549
+ html += "</tr></thead><tbody>"
550
+
551
+ for row in rows:
552
+ html += "<tr>"
553
+ for i, col in enumerate(row):
554
+ header = headers[i]
555
+ style = "padding: 10px; border: 1px solid #555; vertical-align: top;"
556
+
557
+ # Let long explanation columns wrap; keep ID-like columns on a single line
558
+ if header in ["Country Explanation", "Sample Type Explanation"]:
559
+ style += " max-width: 400px; word-wrap: break-word; white-space: normal;"
560
+ elif header in ["Sample ID", "Predicted Country", "Predicted Sample Type", "Time cost"]:
561
+ style += " white-space: nowrap; text-overflow: ellipsis; max-width: 200px; overflow: hidden;"
562
+
563
+ # if header == "Sources" and isinstance(col, str) and col.strip().lower().startswith("http"):
564
+ # col = f"<a href='{col}' target='_blank' style='color: #4ea1f3; text-decoration: underline;'>{col}</a>"
565
+
566
+ #html += f"<td style='{style}'>{col}</td>"
567
+ if header == "Sources" and isinstance(col, str):
568
+ links = [f"<a href='{url.strip()}' target='_blank' style='color: #4ea1f3; text-decoration: underline;'>{url.strip()}</a>" for url in col.strip().split("\n") if url.strip()]
569
+ col = "- "+"<br>- ".join(links)
570
+ elif isinstance(col, str):
571
+ # lines = []
572
+ # for line in col.split("\n"):
573
+ # line = line.strip()
574
+ # if not line:
575
+ # continue
576
+ # if line.lower().startswith("rag_llm-"):
577
+ # content = line[len("rag_llm-"):].strip()
578
+ # line = f"{content} (Method: RAG_LLM)"
579
+ # lines.append(f"- {line}")
580
+ col = col.replace("\n", "<br>")
581
+ #col = col.replace("\t", "&nbsp;&nbsp;&nbsp;&nbsp;")
582
+ #col = "<br>".join(lines)
583
+
584
+ html += f"<td style='{style}'>{col}</td>"
585
+ html += "</tr>"
586
+
587
+ html += "</tbody></table></div></div>"
588
+ return html
589
+
590
+
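For reference, a made-up call showing the row shape make_html_table expects: one inner list per sample, ordered to match the seven headers above, with multiple source URLs separated by newlines.

example_rows = [[
    "AB123456",                                                            # Sample ID (hypothetical)
    "Vietnam",                                                             # Predicted Country
    "Isolate listed with Vietnamese sampling sites in the supplement.",    # Country Explanation
    "modern",                                                              # Predicted Sample Type
    "Described as a present-day donor in the methods.",                    # Sample Type Explanation
    "https://example.org/paper.pdf\nhttps://example.org/supplement.xlsx",  # Sources
    "12.3 s",                                                              # Time cost
]]
html_snippet = make_html_table(example_rows)  # ready to drop into the gr.HTML output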
591
+ # def reset_fields():
592
+ # global_stop_flag.value = False # 💡 Add this to reset the flag
593
+ # return (
594
+ # #gr.update(value=""), # single_accession
595
+ # gr.update(value=""), # raw_text
596
+ # gr.update(value=None), # file_upload
597
+ # #gr.update(value=None), # resume_file
598
+ # #gr.update(value="Single Accession"), # inputMode
599
+ # gr.update(value=[], visible=True), # output_table
600
+ # # gr.update(value="", visible=True), # output_summary
601
+ # # gr.update(value="", visible=True), # output_flag
602
+ # gr.update(visible=False), # status
603
+ # gr.update(visible=False), # results_group
604
+ # gr.update(value="", visible=False), # usage_display
605
+ # gr.update(value="", visible=False), # progress_box
606
+ # )
607
+ def reset_fields():
608
+ global_stop_flag.value = False # Reset the stop flag
609
+
610
+ return (
611
+ gr.update(value=""), # raw_text
612
+ gr.update(value=None), # file_upload
613
+ gr.update(value=[], visible=True), # output_table
614
+ gr.update(value="", visible=True), # status — reset and make visible again
615
+ gr.update(visible=False), # results_group
616
+ gr.update(value="", visible=True), # usage_display — reset and make visible again
617
+ gr.update(value="", visible=True), # progress_box — reset AND visible!
618
+ # report-related reset below
619
+ gr.update(value="", visible=False), # report_textbox
620
+ gr.update(visible=False), # submit_report_button
621
+ gr.update(value="", visible=False), # status_report
622
+ )
623
+ #inputMode.change(fn=toggle_input_mode, inputs=inputMode, outputs=[single_input_group, batch_input_group])
624
+ #run_button.click(fn=classify_with_loading, inputs=[], outputs=[status])
625
+ # run_button.click(
626
+ # fn=classify_dynamic,
627
+ # inputs=[single_accession, file_upload, raw_text, resume_file,user_email,inputMode],
628
+ # outputs=[output_table,
629
+ # #output_summary, output_flag,
630
+ # results_group, download_file, usage_display,status, progress_box]
631
+ # )
632
+
633
+ # run_button.click(
634
+ # fn=threaded_batch_runner,
635
+ # #inputs=[file_upload, raw_text, resume_file, user_email],
636
+ # inputs=[file_upload, raw_text, user_email],
637
+ # outputs=[output_table, results_group, download_file, usage_display, status, progress_box]
638
+ # )
639
+ # run_button.click(
640
+ # fn=threaded_batch_runner,
641
+ # inputs=[file_upload, raw_text, user_email],
642
+ # outputs=[output_table, results_group, download_file, usage_display, status, progress_box],
643
+ # every=0.5 # <-- this tells Gradio to expect streaming
644
+ # )
645
+ # output_table = gr.HTML()
646
+ # results_group = gr.Group(visible=False)
647
+ # download_file = gr.File(visible=False)
648
+ # usage_display = gr.Markdown(visible=False)
649
+ # status = gr.Markdown(visible=False)
650
+ # progress_box = gr.Textbox(visible=False)
651
+
652
+ # run_button.click(
653
+ # fn=threaded_batch_runner,
654
+ # inputs=[file_upload, raw_text, user_email],
655
+ # outputs=[output_table, results_group, download_file, usage_display, status, progress_box],
656
+ # every=0.5, # streaming enabled
657
+ # show_progress="full"
658
+ # )
659
+
660
+ # interface.stream(
661
+ # fn=threaded_batch_runner,
662
+ # inputs=[file_upload, raw_text, user_email],
663
+ # outputs=[output_table, results_group, download_file, usage_display, status, progress_box],
664
+ # trigger=run_button,
665
+ # every=0.5,
666
+ # show_progress="full",
667
+ # )
668
+ interface.queue() # No arguments here!
669
+
670
+ run_button.click(
671
+ fn=threaded_batch_runner,
672
+ inputs=[file_upload, raw_text, user_email],
673
+ outputs=[output_table, results_group, download_file, usage_display, status, progress_box],
674
+ concurrency_limit=1, # ✅ correct in Gradio 5.x
675
+ queue=True, # ✅ ensure the queue is used
676
+ #every=0.5
677
+ )
678
+
679
+
680
+
681
+
682
+ stop_button.click(fn=stop_batch, inputs=[], outputs=[status])
683
+
684
+ # reset_button.click(
685
+ # #fn=reset_fields,
686
+ # fn=lambda: (
687
+ # gr.update(value=""), gr.update(value=""), gr.update(value=None), gr.update(value=None), gr.update(value="Single Accession"),
688
+ # gr.update(value=[], visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value="", visible=False), gr.update(value="", visible=False)
689
+ # ),
690
+ # inputs=[],
691
+ # outputs=[
692
+ # single_accession, raw_text, file_upload, resume_file,inputMode,
693
+ # output_table,# output_summary, output_flag,
694
+ # status, results_group, usage_display, progress_box
695
+ # ]
696
+ # )
697
+ #stop_button.click(fn=lambda sf: (gr.update(value="❌ Stopping...", visible=True), setattr(sf, "value", True) or sf), inputs=[gr.State(stop_flag)], outputs=[status, gr.State(stop_flag)])
698
+
699
+ reset_button.click(
700
+ fn=reset_fields,
701
+ inputs=[],
702
+ #outputs=[raw_text, file_upload, resume_file, output_table, status, results_group, usage_display, progress_box]
703
+ outputs=[raw_text, file_upload, output_table, status, results_group, usage_display, progress_box,
704
+ report_textbox,
705
+ submit_report_button,
706
+ status_report]
707
+ )
708
+
709
+ # download_button.click(
710
+ # fn=mtdna_backend.save_batch_output,
711
+ # #inputs=[output_table, output_summary, output_flag, output_type],
712
+ # inputs=[output_table, output_type],
713
+ # outputs=[download_file])
714
+
715
+ # submit_feedback.click(
716
+ # fn=mtdna_backend.store_feedback_to_google_sheets,
717
+ # inputs=[single_accession, q1, q2, contact], outputs=feedback_status
718
+ # )
719
+ report_button.click(fn=show_report_ui, outputs=[report_textbox, submit_report_button, status_report])
720
+ submit_report_button.click(fn=handle_submission, inputs=[report_textbox, user_email], outputs=[status_report, report_textbox, submit_report_button])
721
+
722
+ submit_feedback.click(
723
+ fn=mtdna_backend.store_feedback_to_google_sheets,
724
+ inputs=[raw_text, q1, q2, contact],
725
+ outputs=[feedback_status]
726
+ )
727
+ gr.HTML("""
728
+ <style>
729
+ body, html {
730
+ background-color: #121212 !important;
731
+ color: #ffffff !important;
732
+ }
733
+
734
+ .gradio-container, .gr-block, .gr-box, textarea, input, select, .prose, .prose * {
735
+ background-color: #1e1e1e !important;
736
+ color: #ffffff !important;
737
+ border-color: #333 !important;
738
+ }
739
+
740
+ textarea::placeholder,
741
+ input::placeholder {
742
+ color: #aaa !important;
743
+ }
744
+
745
+ button {
746
+ background-color: #2d2d2d !important;
747
+ color: #fff !important;
748
+ border: 1px solid #444 !important;
749
+ }
750
+
751
+ a {
752
+ color: #4ea1f3 !important;
753
+ }
754
+ </style>
755
+ """)
756
+
757
+ # # Custom CSS styles
758
+ # gr.HTML("""
759
+ # <style>
760
+ # /* Ensures both sections are equally spaced with the same background size */
761
+ # #output-summary, #output-flag {
762
+ # background-color: #f0f4f8; /* Light Grey for both */
763
+ # padding: 20px;
764
+ # border-radius: 10px;
765
+ # margin-top: 10px;
766
+ # width: 100%; /* Ensure full width */
767
+ # min-height: 150px; /* Ensures both have a minimum height */
768
+ # box-sizing: border-box; /* Prevents padding from increasing size */
769
+ # display: flex;
770
+ # flex-direction: column;
771
+ # justify-content: space-between;
772
+ # }
773
+
774
+ # /* Specific background colors */
775
+ # #output-summary {
776
+ # background-color: #434a4b;
777
+ # }
778
+
779
+ # #output-flag {
780
+ # background-color: #141616;
781
+ # }
782
+
783
+ # /* Ensuring they are in a row and evenly spaced */
784
+ # .gradio-row {
785
+ # display: flex;
786
+ # justify-content: space-between;
787
+ # width: 100%;
788
+ # }
789
+ # </style>
790
+ # """)
791
+
792
+
793
  interface.launch(share=True,debug=True)
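One detail of the wiring above worth spelling out: stopping is cooperative. stop_button only calls stop_batch, which flips a shared flag, and the generator checks that flag at the top of each accession iteration, so a stop takes effect after the current accession finishes. global_stop_flag and stop_batch are defined earlier in the file; the definitions below are only an illustrative sketch of that pattern.

import multiprocessing

global_stop_flag = multiprocessing.Value("b", False)  # shared boolean flag

def stop_batch():
    global_stop_flag.value = True
    return "🛑 Stop requested — the batch halts after the current accession."

def run_items(items):
    for item in items:
        if global_stop_flag.value:   # polled once per iteration, so stopping is not instantaneous
            break
        print("processing", item)    # stand-in for the real per-accession work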
data_preprocess.py CHANGED
@@ -1,669 +1,746 @@
1
- import re
2
- import os
3
- #import streamlit as st
4
- import subprocess
5
- import re
6
- from Bio import Entrez
7
- from docx import Document
8
- import fitz
9
- import spacy
10
- from spacy.cli import download
11
- from NER.PDF import pdf
12
- from NER.WordDoc import wordDoc
13
- from NER.html import extractHTML
14
- from NER.word2Vec import word2vec
15
- from transformers import pipeline
16
- import urllib.parse, requests
17
- from pathlib import Path
18
- import pandas as pd
19
- import model
20
- import pipeline
21
- import tempfile
22
- import nltk
23
- nltk.download('punkt_tab')
24
- def download_excel_file(url, save_path="temp.xlsx"):
25
- if "view.officeapps.live.com" in url:
26
- parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
27
- real_url = urllib.parse.unquote(parsed_url["src"][0])
28
- response = requests.get(real_url)
29
- with open(save_path, "wb") as f:
30
- f.write(response.content)
31
- return save_path
32
- elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
33
- response = requests.get(url)
34
- response.raise_for_status() # Raises error if download fails
35
- with open(save_path, "wb") as f:
36
- f.write(response.content)
37
- print(len(response.content))
38
- return save_path
39
- else:
40
- print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
41
- return url
42
- def extract_text(link,saveFolder):
43
- text = ""
44
- name = link.split("/")[-1]
45
- #file_path = Path(saveFolder) / name
46
- local_temp_path = os.path.join(tempfile.gettempdir(), name)
47
- pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
48
-
49
- # pdf
50
- if link.endswith(".pdf"):
51
- # if file_path.is_file():
52
- # link = saveFolder + "/" + name
53
- # print("File exists.")
54
- p = pdf.PDF(local_temp_path, saveFolder)
55
- #p = pdf.PDF(link,saveFolder)
56
- text = p.extractTextWithPDFReader()
57
- #text_exclude_table = p.extract_text_excluding_tables()
58
- # worddoc
59
- elif link.endswith(".doc") or link.endswith(".docx"):
60
- d = wordDoc.wordDoc(local_temp_path,saveFolder)
61
- text = d.extractTextByPage()
62
- # html
63
- if link.split(".")[-1].lower() not in "xlsx":
64
- if "http" in link or "html" in link:
65
- html = extractHTML.HTML("",link)
66
- text = html.getListSection() # the text already clean
67
- return text
68
- def extract_table(link,saveFolder):
69
- table = []
70
- name = link.split("/")[-1]
71
- #file_path = Path(saveFolder) / name
72
- local_temp_path = os.path.join(tempfile.gettempdir(), name)
73
- pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
74
- # pdf
75
- if link.endswith(".pdf"):
76
- # if file_path.is_file():
77
- # link = saveFolder + "/" + name
78
- # print("File exists.")
79
- p = pdf.PDF(local_temp_path,saveFolder)
80
- table = p.extractTable()
81
- # worddoc
82
- elif link.endswith(".doc") or link.endswith(".docx"):
83
- d = wordDoc.wordDoc(local_temp_path,saveFolder)
84
- table = d.extractTableAsList()
85
- # excel
86
- elif link.split(".")[-1].lower() in "xlsx":
87
- # download excel file if it not downloaded yet
88
- savePath = saveFolder +"/"+ link.split("/")[-1]
89
- excelPath = download_excel_file(link, savePath)
90
- try:
91
- #xls = pd.ExcelFile(excelPath)
92
- xls = pd.ExcelFile(local_temp_path)
93
- table_list = []
94
- for sheet_name in xls.sheet_names:
95
- df = pd.read_excel(xls, sheet_name=sheet_name)
96
- cleaned_table = df.fillna("").astype(str).values.tolist()
97
- table_list.append(cleaned_table)
98
- table = table_list
99
- except Exception as e:
100
- print("❌ Failed to extract tables from Excel:", e)
101
- # html
102
- elif "http" in link or "html" in link:
103
- html = extractHTML.HTML("",link)
104
- table = html.extractTable() # table is a list
105
- table = clean_tables_format(table)
106
- return table
107
-
108
- def clean_tables_format(tables):
109
- """
110
- Ensures all tables are in consistent format: List[List[List[str]]]
111
- Cleans by:
112
- - Removing empty strings and rows
113
- - Converting all cells to strings
114
- - Handling DataFrames and list-of-lists
115
- """
116
- cleaned = []
117
- if tables:
118
- for table in tables:
119
- standardized = []
120
-
121
- # Case 1: Pandas DataFrame
122
- if isinstance(table, pd.DataFrame):
123
- table = table.fillna("").astype(str).values.tolist()
124
-
125
- # Case 2: List of Lists
126
- if isinstance(table, list) and all(isinstance(row, list) for row in table):
127
- for row in table:
128
- filtered_row = [str(cell).strip() for cell in row if str(cell).strip()]
129
- if filtered_row:
130
- standardized.append(filtered_row)
131
-
132
- if standardized:
133
- cleaned.append(standardized)
134
-
135
- return cleaned
136
-
137
- import json
138
- def normalize_text_for_comparison(s: str) -> str:
139
- """
140
- Normalizes text for robust comparison by:
141
- 1. Converting to lowercase.
142
- 2. Replacing all types of newlines with a single consistent newline (\n).
143
- 3. Removing extra spaces (e.g., multiple spaces, leading/trailing spaces on lines).
144
- 4. Stripping leading/trailing whitespace from the entire string.
145
- """
146
- s = s.lower()
147
- s = s.replace('\r\n', '\n') # Handle Windows newlines
148
- s = s.replace('\r', '\n') # Handle Mac classic newlines
149
-
150
- # Replace sequences of whitespace (including multiple newlines) with a single space
151
- # This might be too aggressive if you need to preserve paragraph breaks,
152
- # but good for exact word-sequence matching.
153
- s = re.sub(r'\s+', ' ', s)
154
-
155
- return s.strip()
156
- def merge_text_and_tables(text, tables, max_tokens=12000, keep_tables=True, tokenizer="cl100k_base", accession_id=None, isolate=None):
157
- """
158
- Merge cleaned text and table into one string for LLM input.
159
- - Avoids duplicating tables already in text
160
- - Extracts only relevant rows from large tables
161
- - Skips or saves oversized tables
162
- """
163
- import importlib
164
- json = importlib.import_module("json")
165
-
166
- def estimate_tokens(text_str):
167
- try:
168
- enc = tiktoken.get_encoding(tokenizer)
169
- return len(enc.encode(text_str))
170
- except:
171
- return len(text_str) // 4 # Fallback estimate
172
-
173
- def is_table_relevant(table, keywords, accession_id=None):
174
- flat = " ".join(" ".join(row).lower() for row in table)
175
- if accession_id and accession_id.lower() in flat:
176
- return True
177
- return any(kw.lower() in flat for kw in keywords)
178
- preview, preview1 = "",""
179
- llm_input = "## Document Text\n" + text.strip() + "\n"
180
- clean_text = normalize_text_for_comparison(text)
181
-
182
- if tables:
183
- for idx, table in enumerate(tables):
184
- keywords = ["province","district","region","village","location", "country", "region", "origin", "ancient", "modern"]
185
- if accession_id: keywords += [accession_id.lower()]
186
- if isolate: keywords += [isolate.lower()]
187
- if is_table_relevant(table, keywords, accession_id):
188
- if len(table) > 0:
189
- for tab in table:
190
- preview = " ".join(tab) if tab else ""
191
- preview1 = "\n".join(tab) if tab else ""
192
- clean_preview = normalize_text_for_comparison(preview)
193
- clean_preview1 = normalize_text_for_comparison(preview1)
194
- if clean_preview not in clean_text:
195
- if clean_preview1 not in clean_text:
196
- table_str = json.dumps([tab], indent=2)
197
- llm_input += f"## Table {idx+1}\n{table_str}\n"
198
- return llm_input.strip()
199
-
200
- def preprocess_document(link, saveFolder, accession=None, isolate=None):
201
- try:
202
- text = extract_text(link, saveFolder)
203
- except: text = ""
204
- try:
205
- tables = extract_table(link, saveFolder)
206
- except: tables = []
207
- if accession: accession = accession
208
- if isolate: isolate = isolate
209
- try:
210
- final_input = merge_text_and_tables(text, tables, max_tokens=12000, accession_id=accession, isolate=isolate)
211
- except: final_input = ""
212
- return text, tables, final_input
213
-
214
- def extract_sentences(text):
215
- sentences = re.split(r'(?<=[.!?])\s+', text)
216
- return [s.strip() for s in sentences if s.strip()]
217
-
218
- def is_irrelevant_number_sequence(text):
219
- if re.search(r'\b[A-Z]{2,}\d+\b|\b[A-Za-z]+\s+\d+\b', text, re.IGNORECASE):
220
- return False
221
- word_count = len(re.findall(r'\b[A-Za-z]{2,}\b', text))
222
- number_count = len(re.findall(r'\b\d[\d\.]*\b', text))
223
- total_tokens = len(re.findall(r'\S+', text))
224
- if total_tokens > 0 and (word_count / total_tokens < 0.2) and (number_count / total_tokens > 0.5):
225
- return True
226
- elif re.fullmatch(r'(\d+(\.\d+)?\s*)+', text.strip()):
227
- return True
228
- return False
229
-
230
- def remove_isolated_single_digits(sentence):
231
- tokens = sentence.split()
232
- filtered_tokens = []
233
- for token in tokens:
234
- if token == '0' or token == '1':
235
- pass
236
- else:
237
- filtered_tokens.append(token)
238
- return ' '.join(filtered_tokens).strip()
239
-
240
- def get_contextual_sentences_BFS(text_content, keyword, depth=2):
241
- def extract_codes(sentence):
242
- # Match codes like 'A1YU101', 'KM1', 'MO6' — at least 2 letters + numbers
243
- return [code for code in re.findall(r'\b[A-Z]{2,}[0-9]+\b', sentence, re.IGNORECASE)]
244
- sentences = extract_sentences(text_content)
245
- relevant_sentences = set()
246
- initial_keywords = set()
247
-
248
- # Define a regex to capture codes like A1YU101 or KM1
249
- # This pattern looks for an alphanumeric sequence followed by digits at the end of the string
250
- code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
251
-
252
- # Attempt to parse the keyword into its prefix and numerical part using re.search
253
- keyword_match = code_pattern.search(keyword)
254
-
255
- keyword_prefix = None
256
- keyword_num = None
257
-
258
- if keyword_match:
259
- keyword_prefix = keyword_match.group(1).lower()
260
- keyword_num = int(keyword_match.group(2))
261
-
262
- for sentence in sentences:
263
- sentence_added = False
264
-
265
- # 1. Check for exact match of the keyword
266
- if re.search(r'\b' + re.escape(keyword) + r'\b', sentence, re.IGNORECASE):
267
- relevant_sentences.add(sentence.strip())
268
- initial_keywords.add(keyword.lower())
269
- sentence_added = True
270
-
271
- # 2. Check for range patterns (e.g., A1YU101-A1YU137)
272
- # The range pattern should be broad enough to capture the full code string within the range.
273
- range_matches = re.finditer(r'([A-Z0-9]+-\d+)', sentence, re.IGNORECASE) # More specific range pattern if needed, or rely on full code pattern below
274
- range_matches = re.finditer(r'([A-Z0-9]+\d+)-([A-Z0-9]+\d+)', sentence, re.IGNORECASE) # This is the more robust range pattern
275
-
276
- for r_match in range_matches:
277
- start_code_str = r_match.group(1)
278
- end_code_str = r_match.group(2)
279
-
280
- # CRITICAL FIX: Use code_pattern.search for start_match and end_match
281
- start_match = code_pattern.search(start_code_str)
282
- end_match = code_pattern.search(end_code_str)
283
-
284
- if keyword_prefix and keyword_num is not None and start_match and end_match:
285
- start_prefix = start_match.group(1).lower()
286
- end_prefix = end_match.group(1).lower()
287
- start_num = int(start_match.group(2))
288
- end_num = int(end_match.group(2))
289
-
290
- # Check if the keyword's prefix matches and its number is within the range
291
- if keyword_prefix == start_prefix and \
292
- keyword_prefix == end_prefix and \
293
- start_num <= keyword_num <= end_num:
294
- relevant_sentences.add(sentence.strip())
295
- initial_keywords.add(start_code_str.lower())
296
- initial_keywords.add(end_code_str.lower())
297
- sentence_added = True
298
- break # Only need to find one matching range per sentence
299
-
300
- # 3. If the sentence was added due to exact match or range, add all its alphanumeric codes
301
- # to initial_keywords to ensure graph traversal from related terms.
302
- if sentence_added:
303
- for word in extract_codes(sentence):
304
- initial_keywords.add(word.lower())
305
-
306
-
307
- # Build word_to_sentences mapping for all sentences
308
- word_to_sentences = {}
309
- for sent in sentences:
310
- codes_in_sent = set(extract_codes(sent))
311
- for code in codes_in_sent:
312
- word_to_sentences.setdefault(code.lower(), set()).add(sent.strip())
313
-
314
-
315
- # Build the graph
316
- graph = {}
317
- for sent in sentences:
318
- codes = set(extract_codes(sent))
319
- for word1 in codes:
320
- word1_lower = word1.lower()
321
- graph.setdefault(word1_lower, set())
322
- for word2 in codes:
323
- word2_lower = word2.lower()
324
- if word1_lower != word2_lower:
325
- graph[word1_lower].add(word2_lower)
326
-
327
-
328
- # Perform BFS/graph traversal
329
- queue = [(k, 0) for k in initial_keywords if k in word_to_sentences]
330
- visited_words = set(initial_keywords)
331
-
332
- while queue:
333
- current_word, level = queue.pop(0)
334
- if level >= depth:
335
- continue
336
-
337
- relevant_sentences.update(word_to_sentences.get(current_word, []))
338
-
339
- for neighbor in graph.get(current_word, []):
340
- if neighbor not in visited_words:
341
- visited_words.add(neighbor)
342
- queue.append((neighbor, level + 1))
343
-
344
- final_sentences = set()
345
- for sentence in relevant_sentences:
346
- if not is_irrelevant_number_sequence(sentence):
347
- processed_sentence = remove_isolated_single_digits(sentence)
348
- if processed_sentence:
349
- final_sentences.add(processed_sentence)
350
-
351
- return "\n".join(sorted(list(final_sentences)))
352
-
353
-
354
-
355
- def get_contextual_sentences_DFS(text_content, keyword, depth=2):
356
- sentences = extract_sentences(text_content)
357
-
358
- # Build word-to-sentences mapping
359
- word_to_sentences = {}
360
- for sent in sentences:
361
- words_in_sent = set(re.findall(r'\b[A-Za-z0-9\-_\/]+\b', sent))
362
- for word in words_in_sent:
363
- word_to_sentences.setdefault(word.lower(), set()).add(sent.strip())
364
-
365
- # Function to extract codes in a sentence
366
- def extract_codes(sentence):
367
- # Only codes like 'KSK1', 'MG272794', not pure numbers
368
- return [code for code in re.findall(r'\b[A-Z]{2,}[0-9]+\b', sentence, re.IGNORECASE)]
369
-
370
- # DFS with priority based on distance to keyword and early stop if country found
371
- def dfs_traverse(current_word, current_depth, max_depth, visited_words, collected_sentences, parent_sentence=None):
372
- country = "unknown"
373
- if current_depth > max_depth:
374
- return country, False
375
-
376
- if current_word not in word_to_sentences:
377
- return country, False
378
-
379
- for sentence in word_to_sentences[current_word]:
380
- if sentence == parent_sentence:
381
- continue # avoid reusing the same sentence
382
-
383
- collected_sentences.add(sentence)
384
-
385
- #print("current_word:", current_word)
386
- small_sen = extract_context(sentence, current_word, int(len(sentence) / 4))
387
- #print(small_sen)
388
- country = model.get_country_from_text(small_sen)
389
- #print("small context country:", country)
390
- if country.lower() != "unknown":
391
- return country, True
392
- else:
393
- country = model.get_country_from_text(sentence)
394
- #print("full sentence country:", country)
395
- if country.lower() != "unknown":
396
- return country, True
397
-
398
- codes_in_sentence = extract_codes(sentence)
399
- idx = next((i for i, code in enumerate(codes_in_sentence) if code.lower() == current_word.lower()), None)
400
- if idx is None:
401
- continue
402
-
403
- sorted_children = sorted(
404
- [code for code in codes_in_sentence if code.lower() not in visited_words],
405
- key=lambda x: (abs(codes_in_sentence.index(x) - idx),
406
- 0 if codes_in_sentence.index(x) > idx else 1)
407
- )
408
-
409
- #print("sorted_children:", sorted_children)
410
- for child in sorted_children:
411
- child_lower = child.lower()
412
- if child_lower not in visited_words:
413
- visited_words.add(child_lower)
414
- country, should_stop = dfs_traverse(
415
- child_lower, current_depth + 1, max_depth,
416
- visited_words, collected_sentences, parent_sentence=sentence
417
- )
418
- if should_stop:
419
- return country, True
420
-
421
- return country, False
422
-
423
- # Begin DFS
424
- collected_sentences = set()
425
- visited_words = set([keyword.lower()])
426
- country, status = dfs_traverse(keyword.lower(), 0, depth, visited_words, collected_sentences)
427
-
428
- # Filter irrelevant sentences
429
- final_sentences = set()
430
- for sentence in collected_sentences:
431
- if not is_irrelevant_number_sequence(sentence):
432
- processed = remove_isolated_single_digits(sentence)
433
- if processed:
434
- final_sentences.add(processed)
435
- if not final_sentences:
436
- return country, text_content
437
- return country, "\n".join(sorted(list(final_sentences)))
438
-
439
- # Helper function for normalizing text for overlap comparison
440
- def normalize_for_overlap(s: str) -> str:
441
- s = re.sub(r'[^a-zA-Z0-9\s]', ' ', s).lower()
442
- s = re.sub(r'\s+', ' ', s).strip()
443
- return s
444
-
445
- def merge_texts_skipping_overlap(text1: str, text2: str) -> str:
446
- if not text1: return text2
447
- if not text2: return text1
448
-
449
- # Case 1: text2 is fully contained in text1 or vice-versa
450
- if text2 in text1:
451
- return text1
452
- if text1 in text2:
453
- return text2
454
-
455
- # --- Option 1: Original behavior (suffix of text1, prefix of text2) ---
456
- # This is what your function was primarily designed for.
457
- # It looks for the overlap at the "junction" of text1 and text2.
458
-
459
- max_junction_overlap = 0
460
- for i in range(min(len(text1), len(text2)), 0, -1):
461
- suffix1 = text1[-i:]
462
- prefix2 = text2[:i]
463
- # Prioritize exact match, then normalized match
464
- if suffix1 == prefix2:
465
- max_junction_overlap = i
466
- break
467
- elif normalize_for_overlap(suffix1) == normalize_for_overlap(prefix2):
468
- max_junction_overlap = i
469
- break # Take the first (longest) normalized match
470
-
471
- if max_junction_overlap > 0:
472
- merged_text = text1 + text2[max_junction_overlap:]
473
- return re.sub(r'\s+', ' ', merged_text).strip()
474
-
475
- # --- Option 2: Longest Common Prefix (for cases like "Hi, I am Vy.") ---
476
- # This addresses your specific test case where the overlap is at the very beginning of both strings.
477
- # This is often used when trying to deduplicate content that shares a common start.
478
-
479
- longest_common_prefix_len = 0
480
- min_len = min(len(text1), len(text2))
481
- for i in range(min_len):
482
- if text1[i] == text2[i]:
483
- longest_common_prefix_len = i + 1
484
- else:
485
- break
486
-
487
- # If a common prefix is found AND it's a significant portion (e.g., more than a few chars)
488
- # AND the remaining parts are distinct, then apply this merge.
489
- # This is a heuristic and might need fine-tuning.
490
- if longest_common_prefix_len > 0 and \
491
- text1[longest_common_prefix_len:].strip() and \
492
- text2[longest_common_prefix_len:].strip():
493
-
494
- # Only merge this way if the remaining parts are not empty (i.e., not exact duplicates)
495
- # For "Hi, I am Vy. Nice to meet you." and "Hi, I am Vy. Goodbye Vy."
496
- # common prefix is "Hi, I am Vy."
497
- # Remaining text1: " Nice to meet you."
498
- # Remaining text2: " Goodbye Vy."
499
- # So we merge common_prefix + remaining_text1 + remaining_text2
500
-
501
- common_prefix_str = text1[:longest_common_prefix_len]
502
- remainder_text1 = text1[longest_common_prefix_len:]
503
- remainder_text2 = text2[longest_common_prefix_len:]
504
-
505
- merged_text = common_prefix_str + remainder_text1 + remainder_text2
506
- return re.sub(r'\s+', ' ', merged_text).strip()
507
-
508
-
509
- # If neither specific overlap type is found, just concatenate
510
- merged_text = text1 + text2
511
- return re.sub(r'\s+', ' ', merged_text).strip()
512
-
513
- from docx import Document
514
- from pipeline import upload_file_to_drive
515
- # def save_text_to_docx(text_content: str, file_path: str):
516
- # """
517
- # Saves a given text string into a .docx file.
518
-
519
- # Args:
520
- # text_content (str): The text string to save.
521
- # file_path (str): The full path including the filename where the .docx file will be saved.
522
- # Example: '/content/drive/MyDrive/CollectData/Examples/test/SEA_1234/merged_document.docx'
523
- # """
524
- # try:
525
- # document = Document()
526
-
527
- # # Add the entire text as a single paragraph, or split by newlines for multiple paragraphs
528
- # for paragraph_text in text_content.split('\n'):
529
- # document.add_paragraph(paragraph_text)
530
-
531
- # document.save(file_path)
532
- # print(f"Text successfully saved to '{file_path}'")
533
- # except Exception as e:
534
- # print(f"Error saving text to docx file: {e}")
535
- # def save_text_to_docx(text_content: str, filename: str, drive_folder_id: str):
536
- # """
537
- # Saves a given text string into a .docx file locally, then uploads to Google Drive.
538
-
539
- # Args:
540
- # text_content (str): The text string to save.
541
- # filename (str): The target .docx file name, e.g. 'BRU18_merged_document.docx'.
542
- # drive_folder_id (str): Google Drive folder ID where to upload the file.
543
- # """
544
- # try:
545
- # # Save to temporary local path first
546
- # print("file name: ", filename)
547
- # print("length text content: ", len(text_content))
548
- # local_path = os.path.join(tempfile.gettempdir(), filename)
549
- # document = Document()
550
- # for paragraph_text in text_content.split('\n'):
551
- # document.add_paragraph(paragraph_text)
552
- # document.save(local_path)
553
- # print(f"✅ Text saved locally to: {local_path}")
554
-
555
- # # ✅ Upload to Drive
556
- # pipeline.upload_file_to_drive(local_path, filename, drive_folder_id)
557
- # print(f"✅ Uploaded '{filename}' to Google Drive folder ID: {drive_folder_id}")
558
-
559
- # except Exception as e:
560
- # print(f"❌ Error saving or uploading DOCX: {e}")
561
- def save_text_to_docx(text_content: str, full_local_path: str):
562
- document = Document()
563
- for paragraph_text in text_content.split('\n'):
564
- document.add_paragraph(paragraph_text)
565
- document.save(full_local_path)
566
- print(f"✅ Saved DOCX locally: {full_local_path}")
567
-
568
-
569
-
570
- '''2 scenerios:
571
- - quick look then found then deepdive and directly get location then stop
572
- - quick look then found then deepdive but not find location then hold the related words then
573
- look another files iteratively for each related word and find location and stop'''
574
- def extract_context(text, keyword, window=500):
575
- # firstly try accession number
576
- code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
577
-
578
- # Attempt to parse the keyword into its prefix and numerical part using re.search
579
- keyword_match = code_pattern.search(keyword)
580
-
581
- keyword_prefix = None
582
- keyword_num = None
583
-
584
- if keyword_match:
585
- keyword_prefix = keyword_match.group(1).lower()
586
- keyword_num = int(keyword_match.group(2))
587
- text = text.lower()
588
- idx = text.find(keyword.lower())
589
- if idx == -1:
590
- if keyword_prefix:
591
- idx = text.find(keyword_prefix)
592
- if idx == -1:
593
- return "Sample ID not found."
594
- return text[max(0, idx-window): idx+window]
595
- return text[max(0, idx-window): idx+window]
596
- def process_inputToken(filePaths, saveLinkFolder,accession=None, isolate=None):
597
- cache = {}
598
- country = "unknown"
599
- output = ""
600
- tem_output, small_output = "",""
601
- keyword_appear = (False,"")
602
- keywords = []
603
- if isolate: keywords.append(isolate)
604
- if accession: keywords.append(accession)
605
- for f in filePaths:
606
- # scenerio 1: direct location: truncate the context and then use qa model?
607
- if keywords:
608
- for keyword in keywords:
609
- text, tables, final_input = preprocess_document(f,saveLinkFolder, isolate=keyword)
610
- if keyword in final_input:
611
- context = extract_context(final_input, keyword)
612
- # quick look if country already in context and if yes then return
613
- country = model.get_country_from_text(context)
614
- if country != "unknown":
615
- return country, context, final_input
616
- else:
617
- country = model.get_country_from_text(final_input)
618
- if country != "unknown":
619
- return country, context, final_input
620
- else: # might be cross-ref
621
- keyword_appear = (True, f)
622
- cache[f] = context
623
- small_output = merge_texts_skipping_overlap(output, context) + "\n"
624
- chunkBFS = get_contextual_sentences_BFS(small_output, keyword)
625
- countryBFS = model.get_country_from_text(chunkBFS)
626
- countryDFS, chunkDFS = get_contextual_sentences_DFS(output, keyword)
627
- output = merge_texts_skipping_overlap(output, final_input)
628
- if countryDFS != "unknown" and countryBFS != "unknown":
629
- if len(chunkDFS) <= len(chunkBFS):
630
- return countryDFS, chunkDFS, output
631
- else:
632
- return countryBFS, chunkBFS, output
633
- else:
634
- if countryDFS != "unknown":
635
- return countryDFS, chunkDFS, output
636
- if countryBFS != "unknown":
637
- return countryBFS, chunkBFS, output
638
- else:
639
- # scenerio 2:
640
- '''cross-ref: ex: A1YU101 keyword in file 2 which includes KM1 but KM1 in file 1
641
- but if we look at file 1 first then maybe we can have lookup dict which country
642
- such as Thailand as the key and its re'''
643
- cache[f] = final_input
644
- if keyword_appear[0] == True:
645
- for c in cache:
646
- if c!=keyword_appear[1]:
647
- if cache[c].lower() not in output.lower():
648
- output = merge_texts_skipping_overlap(output, cache[c]) + "\n"
649
- chunkBFS = get_contextual_sentences_BFS(output, keyword)
650
- countryBFS = model.get_country_from_text(chunkBFS)
651
- countryDFS, chunkDFS = get_contextual_sentences_DFS(output, keyword)
652
- if countryDFS != "unknown" and countryBFS != "unknown":
653
- if len(chunkDFS) <= len(chunkBFS):
654
- return countryDFS, chunkDFS, output
655
- else:
656
- return countryBFS, chunkBFS, output
657
- else:
658
- if countryDFS != "unknown":
659
- return countryDFS, chunkDFS, output
660
- if countryBFS != "unknown":
661
- return countryBFS, chunkBFS, output
662
- else:
663
- if cache[f].lower() not in output.lower():
664
- output = merge_texts_skipping_overlap(output, cache[f]) + "\n"
665
- if len(output) == 0 or keyword_appear[0]==False:
666
- for c in cache:
667
- if cache[c].lower() not in output.lower():
668
- output = merge_texts_skipping_overlap(output, cache[c]) + "\n"
669
  return country, "", output
 
1
+ import re
2
+ import os
3
+ #import streamlit as st
4
+ import subprocess
5
+ import re
6
+ from Bio import Entrez
7
+ from docx import Document
8
+ import fitz
9
+ import spacy
10
+ from spacy.cli import download
11
+ from NER.PDF import pdf
12
+ from NER.WordDoc import wordDoc
13
+ from NER.html import extractHTML
14
+ from NER.word2Vec import word2vec
15
+ #from transformers import pipeline
16
+ import urllib.parse, requests
17
+ from pathlib import Path
18
+ import pandas as pd
19
+ import model
20
+ import pipeline
21
+ import tempfile
22
+ import nltk
23
+ nltk.download('punkt_tab')
24
+ def download_excel_file(url, save_path="temp.xlsx"):
25
+ if "view.officeapps.live.com" in url:
26
+ parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
27
+ real_url = urllib.parse.unquote(parsed_url["src"][0])
28
+ response = requests.get(real_url)
29
+ with open(save_path, "wb") as f:
30
+ f.write(response.content)
31
+ return save_path
32
+ elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
33
+ response = requests.get(url)
34
+ response.raise_for_status() # Raises error if download fails
35
+ with open(save_path, "wb") as f:
36
+ f.write(response.content)
37
+ print(len(response.content))
38
+ return save_path
39
+ else:
40
+ print("URL must point directly to an .xls or .xlsx file, or it has already been downloaded.")
41
+ return url
42
+ def extract_text(link,saveFolder):
43
+ try:
44
+ text = ""
45
+ name = link.split("/")[-1]
46
+ print("name: ", name)
47
+ #file_path = Path(saveFolder) / name
48
+ local_temp_path = os.path.join(tempfile.gettempdir(), name)
49
+ print("this is local temp path: ", local_temp_path)
50
+ if os.path.exists(local_temp_path):
51
+ input_to_class = local_temp_path
52
+ print("exist")
53
+ else:
54
+ #input_to_class = link # Let the class handle downloading
55
+ # 1. Check if file exists in shared Google Drive folder
56
+ file_id = pipeline.find_drive_file(name, saveFolder)
57
+ if file_id:
58
+ print("📥 Downloading from Google Drive...")
59
+ pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
60
+ else:
61
+ print("🌐 Downloading from web link...")
62
+ response = requests.get(link)
63
+ with open(local_temp_path, 'wb') as f:
64
+ f.write(response.content)
65
+ print("✅ Saved locally.")
66
+
67
+ # 2. Upload to Drive so it's available for later
68
+ pipeline.upload_file_to_drive(local_temp_path, name, saveFolder)
69
+
70
+ input_to_class = local_temp_path
71
+ print(input_to_class)
72
+ # pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
73
+ # pdf
74
+ if link.endswith(".pdf"):
75
+ # if file_path.is_file():
76
+ # link = saveFolder + "/" + name
77
+ # print("File exists.")
78
+ #p = pdf.PDF(local_temp_path, saveFolder)
79
+ print("inside pdf and input to class: ", input_to_class)
80
+ print("save folder in extract text: ", saveFolder)
81
+ p = pdf.PDF(input_to_class, saveFolder)
82
+ #p = pdf.PDF(link,saveFolder)
83
+ text = p.extractTextWithPDFReader()
84
+ print("text from pdf:")
85
+ print(text)
86
+ #text_exclude_table = p.extract_text_excluding_tables()
87
+ # worddoc
88
+ elif link.endswith(".doc") or link.endswith(".docx"):
89
+ #d = wordDoc.wordDoc(local_temp_path,saveFolder)
90
+ d = wordDoc.wordDoc(input_to_class,saveFolder)
91
+ text = d.extractTextByPage()
92
+ # html
93
+ else:
94
+ if link.split(".")[-1].lower() not in "xlsx":
95
+ if "http" in link or "html" in link:
96
+ print("html link: ", link)
97
+ html = extractHTML.HTML("",link)
98
+ text = html.getListSection() # the text already clean
99
+ print("text html: ")
100
+ print(text)
101
+ # Cleanup: delete the local temp file
102
+ if name:
103
+ if os.path.exists(local_temp_path):
104
+ os.remove(local_temp_path)
105
+ print(f"🧹 Deleted local temp file: {local_temp_path}")
106
+ print("done extract text")
107
+ except Exception as e:
108
+ print(f"❌ extract_text failed for {link}: {e}")
+ text = ""
109
+ return text
110
+
111
+ def extract_table(link,saveFolder):
112
+ try:
113
+ table = []
114
+ name = link.split("/")[-1]
115
+ #file_path = Path(saveFolder) / name
116
+ local_temp_path = os.path.join(tempfile.gettempdir(), name)
117
+ if os.path.exists(local_temp_path):
118
+ input_to_class = local_temp_path
119
+ print("exist")
120
+ else:
121
+ #input_to_class = link # Let the class handle downloading
122
+ # 1. Check if file exists in shared Google Drive folder
123
+ file_id = pipeline.find_drive_file(name, saveFolder)
124
+ if file_id:
125
+ print("📥 Downloading from Google Drive...")
126
+ pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
127
+ else:
128
+ print("🌐 Downloading from web link...")
129
+ response = requests.get(link)
130
+ with open(local_temp_path, 'wb') as f:
131
+ f.write(response.content)
132
+ print("✅ Saved locally.")
133
+
134
+ # 2. Upload to Drive so it's available for later
135
+ pipeline.upload_file_to_drive(local_temp_path, name, saveFolder)
136
+
137
+ input_to_class = local_temp_path
138
+ print(input_to_class)
139
+ #pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
140
+ # pdf
141
+ if link.endswith(".pdf"):
142
+ # if file_path.is_file():
143
+ # link = saveFolder + "/" + name
144
+ # print("File exists.")
145
+ #p = pdf.PDF(local_temp_path,saveFolder)
146
+ p = pdf.PDF(input_to_class,saveFolder)
147
+ table = p.extractTable()
148
+ # worddoc
149
+ elif link.endswith(".doc") or link.endswith(".docx"):
150
+ #d = wordDoc.wordDoc(local_temp_path,saveFolder)
151
+ d = wordDoc.wordDoc(input_to_class,saveFolder)
152
+ table = d.extractTableAsList()
153
+ # excel
154
+ elif link.split(".")[-1].lower() in "xlsx":
155
+ # download excel file if it not downloaded yet
156
+ savePath = saveFolder +"/"+ link.split("/")[-1]
157
+ excelPath = download_excel_file(link, savePath)
158
+ try:
159
+ #xls = pd.ExcelFile(excelPath)
160
+ xls = pd.ExcelFile(local_temp_path)
161
+ table_list = []
162
+ for sheet_name in xls.sheet_names:
163
+ df = pd.read_excel(xls, sheet_name=sheet_name)
164
+ cleaned_table = df.fillna("").astype(str).values.tolist()
165
+ table_list.append(cleaned_table)
166
+ table = table_list
167
+ except Exception as e:
168
+ print("❌ Failed to extract tables from Excel:", e)
169
+ # html
170
+ elif "http" in link or "html" in link:
171
+ html = extractHTML.HTML("",link)
172
+ table = html.extractTable() # table is a list
173
+ table = clean_tables_format(table)
174
+ # Cleanup: delete the local temp file
175
+ if os.path.exists(local_temp_path):
176
+ os.remove(local_temp_path)
177
+ print(f"🧹 Deleted local temp file: {local_temp_path}")
178
+ except Exception as e:
179
+ print(f"❌ extract_table failed for {link}: {e}")
+ table = []
180
+ return table
181
+
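extract_text and extract_table repeat the same fetch-and-cache steps (local temp file, then the shared Drive folder, then the web, with a re-upload to Drive). As a sketch only — not part of this commit — that logic could be factored into one helper built on the same pipeline functions:

def fetch_to_temp(link, drive_folder):
    """Return a local temp path for `link`, preferring the temp dir, then Drive, then the web."""
    name = link.split("/")[-1]
    local_path = os.path.join(tempfile.gettempdir(), name)
    if not os.path.exists(local_path):
        if pipeline.find_drive_file(name, drive_folder):
            pipeline.download_file_from_drive(name, drive_folder, local_path)
        else:
            response = requests.get(link)
            with open(local_path, "wb") as f:
                f.write(response.content)
            # Cache the fresh download in Drive so later runs skip the web fetch.
            pipeline.upload_file_to_drive(local_path, name, drive_folder)
    return local_path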
182
+ def clean_tables_format(tables):
183
+ """
184
+ Ensures all tables are in consistent format: List[List[List[str]]]
185
+ Cleans by:
186
+ - Removing empty strings and rows
187
+ - Converting all cells to strings
188
+ - Handling DataFrames and list-of-lists
189
+ """
190
+ cleaned = []
191
+ if tables:
192
+ for table in tables:
193
+ standardized = []
194
+
195
+ # Case 1: Pandas DataFrame
196
+ if isinstance(table, pd.DataFrame):
197
+ table = table.fillna("").astype(str).values.tolist()
198
+
199
+ # Case 2: List of Lists
200
+ if isinstance(table, list) and all(isinstance(row, list) for row in table):
201
+ for row in table:
202
+ filtered_row = [str(cell).strip() for cell in row if str(cell).strip()]
203
+ if filtered_row:
204
+ standardized.append(filtered_row)
205
+
206
+ if standardized:
207
+ cleaned.append(standardized)
208
+
209
+ return cleaned
210
+
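A quick illustration with made-up data: a DataFrame and a ragged list-of-lists both come out as List[List[List[str]]], with NaNs/empty cells and all-empty rows dropped (note that DataFrame headers are not kept, since only the cell values are used).

import pandas as pd

raw_tables = [
    pd.DataFrame({"Sample": ["KM1", None], "Country": ["Laos", "Thailand"]}),
    [["ID", "", "Region"], ["", "", ""], ["A1YU101", "north", ""]],
]
cleaned = clean_tables_format(raw_tables)
# cleaned == [
#     [['KM1', 'Laos'], ['Thailand']],           # NaN -> '' -> dropped from its row
#     [['ID', 'Region'], ['A1YU101', 'north']],  # empty cells and the all-empty row removed
# ]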
211
+ import json
212
+ def normalize_text_for_comparison(s: str) -> str:
213
+ """
214
+ Normalizes text for robust comparison by:
215
+ 1. Converting to lowercase.
216
+ 2. Replacing all types of newlines with a single consistent newline (\n).
217
+ 3. Removing extra spaces (e.g., multiple spaces, leading/trailing spaces on lines).
218
+ 4. Stripping leading/trailing whitespace from the entire string.
219
+ """
220
+ s = s.lower()
221
+ s = s.replace('\r\n', '\n') # Handle Windows newlines
222
+ s = s.replace('\r', '\n') # Handle Mac classic newlines
223
+
224
+ # Replace sequences of whitespace (including multiple newlines) with a single space
225
+ # This might be too aggressive if you need to preserve paragraph breaks,
226
+ # but good for exact word-sequence matching.
227
+ s = re.sub(r'\s+', ' ', s)
228
+
229
+ return s.strip()
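For instance (hypothetical string), the normalization collapses case, newline styles and whitespace runs, which is what makes the substring checks in merge_text_and_tables below reliable:

normalize_text_for_comparison("Sample  KM1\r\n  from   Laos ")
# -> 'sample km1 from laos'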
230
+ def merge_text_and_tables(text, tables, max_tokens=12000, keep_tables=True, tokenizer="cl100k_base", accession_id=None, isolate=None):
231
+ """
232
+ Merge cleaned text and table into one string for LLM input.
233
+ - Avoids duplicating tables already in text
234
+ - Extracts only relevant rows from large tables
235
+ - Skips or saves oversized tables
236
+ """
237
+ import importlib
238
+ json = importlib.import_module("json")
239
+
240
+ def estimate_tokens(text_str):
241
+ try:
242
+ enc = tiktoken.get_encoding(tokenizer)
243
+ return len(enc.encode(text_str))
244
+ except:
245
+ return len(text_str) // 4 # Fallback estimate
246
+
247
+ def is_table_relevant(table, keywords, accession_id=None):
248
+ flat = " ".join(" ".join(row).lower() for row in table)
249
+ if accession_id and accession_id.lower() in flat:
250
+ return True
251
+ return any(kw.lower() in flat for kw in keywords)
252
+ preview, preview1 = "",""
253
+ llm_input = "## Document Text\n" + text.strip() + "\n"
254
+ clean_text = normalize_text_for_comparison(text)
255
+
256
+ if tables:
257
+ for idx, table in enumerate(tables):
258
+ keywords = ["province", "district", "region", "village", "location", "country", "origin", "ancient", "modern"]
259
+ if accession_id: keywords += [accession_id.lower()]
260
+ if isolate: keywords += [isolate.lower()]
261
+ if is_table_relevant(table, keywords, accession_id):
262
+ if len(table) > 0:
263
+ for tab in table:
264
+ preview = " ".join(tab) if tab else ""
265
+ preview1 = "\n".join(tab) if tab else ""
266
+ clean_preview = normalize_text_for_comparison(preview)
267
+ clean_preview1 = normalize_text_for_comparison(preview1)
268
+ if clean_preview not in clean_text:
269
+ if clean_preview1 not in clean_text:
270
+ table_str = json.dumps([tab], indent=2)
271
+ llm_input += f"## Table {idx+1}\n{table_str}\n"
272
+ return llm_input.strip()
273
+
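A small made-up example of the merge behaviour: the document text always comes first under '## Document Text', and a table is appended (row by row, as JSON under '## Table 1') only when it mentions the accession/isolate or a location keyword and its rows are not already present in the text.

merged = merge_text_and_tables(
    text="The ancient samples were collected in northern Laos in 2004.",
    tables=[[["Sample", "Country"], ["KM1", "Laos"]]],
    accession_id="KM1",
)
# merged starts with '## Document Text' followed by the paragraph above,
# then '## Table 1' blocks containing the JSON-dumped rows of the matching table.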
274
+ def preprocess_document(link, saveFolder, accession=None, isolate=None):
275
+ try:
276
+ text = extract_text(link, saveFolder)
277
+ print("text and link")
278
+ print(link)
279
+ print(text)
280
+ except: text = ""
281
+ try:
282
+ tables = extract_table(link, saveFolder)
283
+ except: tables = []
284
+ if accession: accession = accession
285
+ if isolate: isolate = isolate
286
+ try:
287
+ final_input = merge_text_and_tables(text, tables, max_tokens=12000, accession_id=accession, isolate=isolate)
288
+ except: final_input = ""
289
+ return text, tables, final_input
290
+
291
+ def extract_sentences(text):
292
+ sentences = re.split(r'(?<=[.!?])\s+', text)
293
+ return [s.strip() for s in sentences if s.strip()]
294
+
295
+ def is_irrelevant_number_sequence(text):
296
+ if re.search(r'\b[A-Z]{2,}\d+\b|\b[A-Za-z]+\s+\d+\b', text, re.IGNORECASE):
297
+ return False
298
+ word_count = len(re.findall(r'\b[A-Za-z]{2,}\b', text))
299
+ number_count = len(re.findall(r'\b\d[\d\.]*\b', text))
300
+ total_tokens = len(re.findall(r'\S+', text))
301
+ if total_tokens > 0 and (word_count / total_tokens < 0.2) and (number_count / total_tokens > 0.5):
302
+ return True
303
+ elif re.fullmatch(r'(\d+(\.\d+)?\s*)+', text.strip()):
304
+ return True
305
+ return False
306
+
307
+ def remove_isolated_single_digits(sentence):
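+ """Remove standalone '0' and '1' tokens from the sentence."""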
308
+ tokens = sentence.split()
309
+ filtered_tokens = []
310
+ for token in tokens:
311
+ if token == '0' or token == '1':
312
+ pass
313
+ else:
314
+ filtered_tokens.append(token)
315
+ return ' '.join(filtered_tokens).strip()
316
+
317
+ def get_contextual_sentences_BFS(text_content, keyword, depth=2):
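+ """Gather sentences that mention `keyword` directly or via a code range (e.g. A1YU101-A1YU137),
+ then expand by breadth-first traversal over co-occurring sample codes up to `depth` levels."""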
318
+ def extract_codes(sentence):
319
+ # Match codes like 'A1YU101', 'KM1', 'MO6' — at least 2 letters + numbers
320
+ return [code for code in re.findall(r'\b[A-Z]{2,}[0-9]+\b', sentence, re.IGNORECASE)]
321
+ sentences = extract_sentences(text_content)
322
+ relevant_sentences = set()
323
+ initial_keywords = set()
324
+
325
+ # Define a regex to capture codes like A1YU101 or KM1
326
+ # This pattern looks for an alphanumeric sequence followed by digits at the end of the string
327
+ code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
328
+
329
+ # Attempt to parse the keyword into its prefix and numerical part using re.search
330
+ keyword_match = code_pattern.search(keyword)
331
+
332
+ keyword_prefix = None
333
+ keyword_num = None
334
+
335
+ if keyword_match:
336
+ keyword_prefix = keyword_match.group(1).lower()
337
+ keyword_num = int(keyword_match.group(2))
338
+
339
+ for sentence in sentences:
340
+ sentence_added = False
341
+
342
+ # 1. Check for exact match of the keyword
343
+ if re.search(r'\b' + re.escape(keyword) + r'\b', sentence, re.IGNORECASE):
344
+ relevant_sentences.add(sentence.strip())
345
+ initial_keywords.add(keyword.lower())
346
+ sentence_added = True
347
+
348
+ # 2. Check for range patterns (e.g., A1YU101-A1YU137)
349
+ # The range pattern should be broad enough to capture the full code string within the range.
350
+ range_matches = re.finditer(r'([A-Z0-9]+\d+)-([A-Z0-9]+\d+)', sentence, re.IGNORECASE) # Match ranges like A1YU101-A1YU137
352
+
353
+ for r_match in range_matches:
354
+ start_code_str = r_match.group(1)
355
+ end_code_str = r_match.group(2)
356
+
357
+ # CRITICAL FIX: Use code_pattern.search for start_match and end_match
358
+ start_match = code_pattern.search(start_code_str)
359
+ end_match = code_pattern.search(end_code_str)
360
+
361
+ if keyword_prefix and keyword_num is not None and start_match and end_match:
362
+ start_prefix = start_match.group(1).lower()
363
+ end_prefix = end_match.group(1).lower()
364
+ start_num = int(start_match.group(2))
365
+ end_num = int(end_match.group(2))
366
+
367
+ # Check if the keyword's prefix matches and its number is within the range
368
+ if keyword_prefix == start_prefix and \
369
+ keyword_prefix == end_prefix and \
370
+ start_num <= keyword_num <= end_num:
371
+ relevant_sentences.add(sentence.strip())
372
+ initial_keywords.add(start_code_str.lower())
373
+ initial_keywords.add(end_code_str.lower())
374
+ sentence_added = True
375
+ break # Only need to find one matching range per sentence
376
+
377
+ # 3. If the sentence was added due to exact match or range, add all its alphanumeric codes
378
+ # to initial_keywords to ensure graph traversal from related terms.
379
+ if sentence_added:
380
+ for word in extract_codes(sentence):
381
+ initial_keywords.add(word.lower())
382
+
383
+
384
+ # Build word_to_sentences mapping for all sentences
385
+ word_to_sentences = {}
386
+ for sent in sentences:
387
+ codes_in_sent = set(extract_codes(sent))
388
+ for code in codes_in_sent:
389
+ word_to_sentences.setdefault(code.lower(), set()).add(sent.strip())
390
+
391
+
392
+ # Build the graph
393
+ graph = {}
394
+ for sent in sentences:
395
+ codes = set(extract_codes(sent))
396
+ for word1 in codes:
397
+ word1_lower = word1.lower()
398
+ graph.setdefault(word1_lower, set())
399
+ for word2 in codes:
400
+ word2_lower = word2.lower()
401
+ if word1_lower != word2_lower:
402
+ graph[word1_lower].add(word2_lower)
403
+
404
+
405
+ # Perform BFS/graph traversal
406
+ queue = [(k, 0) for k in initial_keywords if k in word_to_sentences]
407
+ visited_words = set(initial_keywords)
408
+
409
+ while queue:
410
+ current_word, level = queue.pop(0)
411
+ if level >= depth:
412
+ continue
413
+
414
+ relevant_sentences.update(word_to_sentences.get(current_word, []))
415
+
416
+ for neighbor in graph.get(current_word, []):
417
+ if neighbor not in visited_words:
418
+ visited_words.add(neighbor)
419
+ queue.append((neighbor, level + 1))
420
+
421
+ final_sentences = set()
422
+ for sentence in relevant_sentences:
423
+ if not is_irrelevant_number_sequence(sentence):
424
+ processed_sentence = remove_isolated_single_digits(sentence)
425
+ if processed_sentence:
426
+ final_sentences.add(processed_sentence)
427
+
428
+ return "\n".join(sorted(list(final_sentences)))
429
+
430
+
431
+
432
+ def get_contextual_sentences_DFS(text_content, keyword, depth=2):
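+ """Depth-first variant: walk co-occurring sample codes and stop as soon as model.get_country_from_text
+ resolves a country; returns (country, supporting sentences), or (country, original text) if none survive filtering."""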
433
+ sentences = extract_sentences(text_content)
434
+
435
+ # Build word-to-sentences mapping
436
+ word_to_sentences = {}
437
+ for sent in sentences:
438
+ words_in_sent = set(re.findall(r'\b[A-Za-z0-9\-_\/]+\b', sent))
439
+ for word in words_in_sent:
440
+ word_to_sentences.setdefault(word.lower(), set()).add(sent.strip())
441
+
442
+ # Function to extract codes in a sentence
443
+ def extract_codes(sentence):
444
+ # Only codes like 'KSK1', 'MG272794', not pure numbers
445
+ return [code for code in re.findall(r'\b[A-Z]{2,}[0-9]+\b', sentence, re.IGNORECASE)]
446
+
447
+ # DFS with priority based on distance to keyword and early stop if country found
448
+ def dfs_traverse(current_word, current_depth, max_depth, visited_words, collected_sentences, parent_sentence=None):
449
+ country = "unknown"
450
+ if current_depth > max_depth:
451
+ return country, False
452
+
453
+ if current_word not in word_to_sentences:
454
+ return country, False
455
+
456
+ for sentence in word_to_sentences[current_word]:
457
+ if sentence == parent_sentence:
458
+ continue # avoid reusing the same sentence
459
+
460
+ collected_sentences.add(sentence)
461
+
462
+ #print("current_word:", current_word)
463
+ small_sen = extract_context(sentence, current_word, int(len(sentence) / 4))
464
+ #print(small_sen)
465
+ country = model.get_country_from_text(small_sen)
466
+ #print("small context country:", country)
467
+ if country.lower() != "unknown":
468
+ return country, True
469
+ else:
470
+ country = model.get_country_from_text(sentence)
471
+ #print("full sentence country:", country)
472
+ if country.lower() != "unknown":
473
+ return country, True
474
+
475
+ codes_in_sentence = extract_codes(sentence)
476
+ idx = next((i for i, code in enumerate(codes_in_sentence) if code.lower() == current_word.lower()), None)
477
+ if idx is None:
478
+ continue
479
+
480
+ sorted_children = sorted(
481
+ [code for code in codes_in_sentence if code.lower() not in visited_words],
482
+ key=lambda x: (abs(codes_in_sentence.index(x) - idx),
483
+ 0 if codes_in_sentence.index(x) > idx else 1)
484
+ )
485
+
486
+ #print("sorted_children:", sorted_children)
487
+ for child in sorted_children:
488
+ child_lower = child.lower()
489
+ if child_lower not in visited_words:
490
+ visited_words.add(child_lower)
491
+ country, should_stop = dfs_traverse(
492
+ child_lower, current_depth + 1, max_depth,
493
+ visited_words, collected_sentences, parent_sentence=sentence
494
+ )
495
+ if should_stop:
496
+ return country, True
497
+
498
+ return country, False
499
+
500
+ # Begin DFS
501
+ collected_sentences = set()
502
+ visited_words = set([keyword.lower()])
503
+ country, status = dfs_traverse(keyword.lower(), 0, depth, visited_words, collected_sentences)
504
+
505
+ # Filter irrelevant sentences
506
+ final_sentences = set()
507
+ for sentence in collected_sentences:
508
+ if not is_irrelevant_number_sequence(sentence):
509
+ processed = remove_isolated_single_digits(sentence)
510
+ if processed:
511
+ final_sentences.add(processed)
512
+ if not final_sentences:
513
+ return country, text_content
514
+ return country, "\n".join(sorted(list(final_sentences)))
515
+
516
+ # Helper function for normalizing text for overlap comparison
517
+ def normalize_for_overlap(s: str) -> str:
518
+ s = re.sub(r'[^a-zA-Z0-9\s]', ' ', s).lower()
519
+ s = re.sub(r'\s+', ' ', s).strip()
520
+ return s
521
+
522
+ def merge_texts_skipping_overlap(text1: str, text2: str) -> str:
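+ """Concatenate text1 and text2 while skipping duplicated overlap, matched either at the junction
+ (suffix of text1 == prefix of text2) or as a shared leading prefix; otherwise plain concatenation."""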
523
+ if not text1: return text2
524
+ if not text2: return text1
525
+
526
+ # Case 1: text2 is fully contained in text1 or vice-versa
527
+ if text2 in text1:
528
+ return text1
529
+ if text1 in text2:
530
+ return text2
531
+
532
+ # --- Option 1: Original behavior (suffix of text1, prefix of text2) ---
533
+ # This is the primary merge strategy.
534
+ # It looks for the overlap at the "junction" of text1 and text2.
535
+
536
+ max_junction_overlap = 0
537
+ for i in range(min(len(text1), len(text2)), 0, -1):
538
+ suffix1 = text1[-i:]
539
+ prefix2 = text2[:i]
540
+ # Prioritize exact match, then normalized match
541
+ if suffix1 == prefix2:
542
+ max_junction_overlap = i
543
+ break
544
+ elif normalize_for_overlap(suffix1) == normalize_for_overlap(prefix2):
545
+ max_junction_overlap = i
546
+ break # Take the first (longest) normalized match
547
+
548
+ if max_junction_overlap > 0:
549
+ merged_text = text1 + text2[max_junction_overlap:]
550
+ return re.sub(r'\s+', ' ', merged_text).strip()
551
+
552
+ # --- Option 2: Longest Common Prefix (for cases like "Hi, I am Vy.") ---
553
+ # This handles the case where the overlap is at the very beginning of both strings.
554
+ # This is often used when trying to deduplicate content that shares a common start.
555
+
556
+ longest_common_prefix_len = 0
557
+ min_len = min(len(text1), len(text2))
558
+ for i in range(min_len):
559
+ if text1[i] == text2[i]:
560
+ longest_common_prefix_len = i + 1
561
+ else:
562
+ break
563
+
564
+ # If a common prefix is found AND it's a significant portion (e.g., more than a few chars)
565
+ # AND the remaining parts are distinct, then apply this merge.
566
+ # This is a heuristic and might need fine-tuning.
567
+ if longest_common_prefix_len > 0 and \
568
+ text1[longest_common_prefix_len:].strip() and \
569
+ text2[longest_common_prefix_len:].strip():
570
+
571
+ # Only merge this way if the remaining parts are not empty (i.e., not exact duplicates)
572
+ # For "Hi, I am Vy. Nice to meet you." and "Hi, I am Vy. Goodbye Vy."
573
+ # common prefix is "Hi, I am Vy."
574
+ # Remaining text1: " Nice to meet you."
575
+ # Remaining text2: " Goodbye Vy."
576
+ # So we merge common_prefix + remaining_text1 + remaining_text2
577
+
578
+ common_prefix_str = text1[:longest_common_prefix_len]
579
+ remainder_text1 = text1[longest_common_prefix_len:]
580
+ remainder_text2 = text2[longest_common_prefix_len:]
581
+
582
+ merged_text = common_prefix_str + remainder_text1 + remainder_text2
583
+ return re.sub(r'\s+', ' ', merged_text).strip()
584
+
585
+
586
+ # If neither specific overlap type is found, just concatenate
587
+ merged_text = text1 + text2
588
+ return re.sub(r'\s+', ' ', merged_text).strip()
589
+
590
+ from docx import Document
591
+ from pipeline import upload_file_to_drive
592
+ # def save_text_to_docx(text_content: str, file_path: str):
593
+ # """
594
+ # Saves a given text string into a .docx file.
595
+
596
+ # Args:
597
+ # text_content (str): The text string to save.
598
+ # file_path (str): The full path including the filename where the .docx file will be saved.
599
+ # Example: '/content/drive/MyDrive/CollectData/Examples/test/SEA_1234/merged_document.docx'
600
+ # """
601
+ # try:
602
+ # document = Document()
603
+
604
+ # # Add the entire text as a single paragraph, or split by newlines for multiple paragraphs
605
+ # for paragraph_text in text_content.split('\n'):
606
+ # document.add_paragraph(paragraph_text)
607
+
608
+ # document.save(file_path)
609
+ # print(f"Text successfully saved to '{file_path}'")
610
+ # except Exception as e:
611
+ # print(f"Error saving text to docx file: {e}")
612
+ # def save_text_to_docx(text_content: str, filename: str, drive_folder_id: str):
613
+ # """
614
+ # Saves a given text string into a .docx file locally, then uploads to Google Drive.
615
+
616
+ # Args:
617
+ # text_content (str): The text string to save.
618
+ # filename (str): The target .docx file name, e.g. 'BRU18_merged_document.docx'.
619
+ # drive_folder_id (str): Google Drive folder ID where to upload the file.
620
+ # """
621
+ # try:
622
+ # # Save to temporary local path first
623
+ # print("file name: ", filename)
624
+ # print("length text content: ", len(text_content))
625
+ # local_path = os.path.join(tempfile.gettempdir(), filename)
626
+ # document = Document()
627
+ # for paragraph_text in text_content.split('\n'):
628
+ # document.add_paragraph(paragraph_text)
629
+ # document.save(local_path)
630
+ # print(f"✅ Text saved locally to: {local_path}")
631
+
632
+ # # Upload to Drive
633
+ # pipeline.upload_file_to_drive(local_path, filename, drive_folder_id)
634
+ # print(f"✅ Uploaded '{filename}' to Google Drive folder ID: {drive_folder_id}")
635
+
636
+ # except Exception as e:
637
+ # print(f"❌ Error saving or uploading DOCX: {e}")
638
+ def save_text_to_docx(text_content: str, full_local_path: str):
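+ """Write text_content to a .docx file at full_local_path, one paragraph per newline-separated line."""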
639
+ document = Document()
640
+ for paragraph_text in text_content.split('\n'):
641
+ document.add_paragraph(paragraph_text)
642
+ document.save(full_local_path)
643
+ print(f"✅ Saved DOCX locally: {full_local_path}")
644
+
645
+
646
+
647
+ '''Two scenarios:
648
+ - quick look finds the keyword and a deep dive directly yields the location, then stop
649
+ - quick look finds the keyword but the deep dive finds no location, so hold the related words and
650
+ search the other files iteratively for each related word until a location is found'''
651
+ def extract_context(text, keyword, window=500):
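+ """Return roughly ±window characters of the lowercased text around the first occurrence of `keyword`
+ (falling back to its code prefix); returns 'Sample ID not found.' when neither is present."""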
652
+ # firstly try accession number
653
+ code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
654
+
655
+ # Attempt to parse the keyword into its prefix and numerical part using re.search
656
+ keyword_match = code_pattern.search(keyword)
657
+
658
+ keyword_prefix = None
659
+ keyword_num = None
660
+
661
+ if keyword_match:
662
+ keyword_prefix = keyword_match.group(1).lower()
663
+ keyword_num = int(keyword_match.group(2))
664
+ text = text.lower()
665
+ idx = text.find(keyword.lower())
666
+ if idx == -1:
667
+ if keyword_prefix:
668
+ idx = text.find(keyword_prefix)
669
+ if idx == -1:
670
+ return "Sample ID not found."
671
+ return text[max(0, idx-window): idx+window]
672
+ return text[max(0, idx-window): idx+window]
673
+ def process_inputToken(filePaths, saveLinkFolder,accession=None, isolate=None):
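+ """Scan the given files for the accession/isolate keywords and return
+ (predicted country, supporting context, merged document text)."""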
674
+ cache = {}
675
+ country = "unknown"
676
+ output = ""
677
+ tem_output, small_output = "",""
678
+ keyword_appear = (False,"")
679
+ keywords = []
680
+ if isolate: keywords.append(isolate)
681
+ if accession: keywords.append(accession)
682
+ for f in filePaths:
683
+ # scenario 1: direct location: truncate the context and then use the QA model
684
+ if keywords:
685
+ for keyword in keywords:
686
+ text, tables, final_input = preprocess_document(f,saveLinkFolder, isolate=keyword)
687
+ if keyword in final_input:
688
+ context = extract_context(final_input, keyword)
689
+ # quick look if country already in context and if yes then return
690
+ country = model.get_country_from_text(context)
691
+ if country != "unknown":
692
+ return country, context, final_input
693
+ else:
694
+ country = model.get_country_from_text(final_input)
695
+ if country != "unknown":
696
+ return country, context, final_input
697
+ else: # might be cross-ref
698
+ keyword_appear = (True, f)
699
+ cache[f] = context
700
+ small_output = merge_texts_skipping_overlap(output, context) + "\n"
701
+ chunkBFS = get_contextual_sentences_BFS(small_output, keyword)
702
+ countryBFS = model.get_country_from_text(chunkBFS)
703
+ countryDFS, chunkDFS = get_contextual_sentences_DFS(output, keyword)
704
+ output = merge_texts_skipping_overlap(output, final_input)
705
+ if countryDFS != "unknown" and countryBFS != "unknown":
706
+ if len(chunkDFS) <= len(chunkBFS):
707
+ return countryDFS, chunkDFS, output
708
+ else:
709
+ return countryBFS, chunkBFS, output
710
+ else:
711
+ if countryDFS != "unknown":
712
+ return countryDFS, chunkDFS, output
713
+ if countryBFS != "unknown":
714
+ return countryBFS, chunkBFS, output
715
+ else:
716
+ # scenario 2:
717
+ '''cross-ref: e.g. the keyword A1YU101 appears in file 2, which also mentions KM1, but KM1 itself
718
+ is described in file 1; if file 1 had been read first, we could build a lookup dict with the country
719
+ (e.g. Thailand) as the key and its related codes as values'''
720
+ cache[f] = final_input
721
+ if keyword_appear[0] == True:
722
+ for c in cache:
723
+ if c!=keyword_appear[1]:
724
+ if cache[c].lower() not in output.lower():
725
+ output = merge_texts_skipping_overlap(output, cache[c]) + "\n"
726
+ chunkBFS = get_contextual_sentences_BFS(output, keyword)
727
+ countryBFS = model.get_country_from_text(chunkBFS)
728
+ countryDFS, chunkDFS = get_contextual_sentences_DFS(output, keyword)
729
+ if countryDFS != "unknown" and countryBFS != "unknown":
730
+ if len(chunkDFS) <= len(chunkBFS):
731
+ return countryDFS, chunkDFS, output
732
+ else:
733
+ return countryBFS, chunkBFS, output
734
+ else:
735
+ if countryDFS != "unknown":
736
+ return countryDFS, chunkDFS, output
737
+ if countryBFS != "unknown":
738
+ return countryBFS, chunkBFS, output
739
+ else:
740
+ if cache[f].lower() not in output.lower():
741
+ output = merge_texts_skipping_overlap(output, cache[f]) + "\n"
742
+ if len(output) == 0 or keyword_appear[0]==False:
743
+ for c in cache:
744
+ if cache[c].lower() not in output.lower():
745
+ output = merge_texts_skipping_overlap(output, cache[c]) + "\n"
746
  return country, "", output
model.py CHANGED
The diff for this file is too large to render. See raw diff
 
mtdna_backend.py CHANGED
@@ -1,885 +1,897 @@
1
- import gradio as gr
2
- from collections import Counter
3
- import csv
4
- import os
5
- from functools import lru_cache
6
- #import app
7
- from mtdna_classifier import classify_sample_location
8
- import data_preprocess, model, pipeline
9
- import subprocess
10
- import json
11
- import pandas as pd
12
- import io
13
- import re
14
- import tempfile
15
- import gspread
16
- from oauth2client.service_account import ServiceAccountCredentials
17
- from io import StringIO
18
- import hashlib
19
- import threading
20
-
21
- # @lru_cache(maxsize=3600)
22
- # def classify_sample_location_cached(accession):
23
- # return classify_sample_location(accession)
24
-
25
- @lru_cache(maxsize=3600)
26
- def pipeline_classify_sample_location_cached(accession):
27
- return pipeline.pipeline_with_gemini([accession])
28
-
29
- # Count and suggest final location
30
- # def compute_final_suggested_location(rows):
31
- # candidates = [
32
- # row.get("Predicted Location", "").strip()
33
- # for row in rows
34
- # if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
35
- # ] + [
36
- # row.get("Inferred Region", "").strip()
37
- # for row in rows
38
- # if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
39
- # ]
40
-
41
- # if not candidates:
42
- # return Counter(), ("Unknown", 0)
43
- # # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
44
- # tokens = []
45
- # for item in candidates:
46
- # # Split by comma, whitespace, and newlines
47
- # parts = re.split(r'[\s,]+', item)
48
- # tokens.extend(parts)
49
-
50
- # # Step 2: Clean and normalize tokens
51
- # tokens = [word.strip() for word in tokens if word.strip().isalpha()] # Keep only alphabetic tokens
52
-
53
- # # Step 3: Count
54
- # counts = Counter(tokens)
55
-
56
- # # Step 4: Get most common
57
- # top_location, count = counts.most_common(1)[0]
58
- # return counts, (top_location, count)
59
-
60
- # Store feedback (with required fields)
61
-
62
- def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
63
- if not answer1.strip() or not answer2.strip():
64
- return "⚠️ Please answer both questions before submitting."
65
-
66
- try:
67
- # ✅ Step: Load credentials from Hugging Face secret
68
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
69
- scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
70
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
71
-
72
- # Connect to Google Sheet
73
- client = gspread.authorize(creds)
74
- sheet = client.open("feedback_mtdna").sheet1 # make sure sheet name matches
75
-
76
- # Append feedback
77
- sheet.append_row([accession, answer1, answer2, contact])
78
- return "✅ Feedback submitted. Thank you!"
79
-
80
- except Exception as e:
81
- return f"❌ Error submitting feedback: {e}"
82
-
83
- # helper function to extract accessions
84
- def extract_accessions_from_input(file=None, raw_text=""):
85
- print(f"RAW TEXT RECEIVED: {raw_text}")
86
- accessions = []
87
- seen = set()
88
- if file:
89
- try:
90
- if file.name.endswith(".csv"):
91
- df = pd.read_csv(file)
92
- elif file.name.endswith(".xlsx"):
93
- df = pd.read_excel(file)
94
- else:
95
- return [], "Unsupported file format. Please upload CSV or Excel."
96
- for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
97
- if acc not in seen:
98
- accessions.append(acc)
99
- seen.add(acc)
100
- except Exception as e:
101
- return [], f"Failed to read file: {e}"
102
-
103
- if raw_text:
104
- text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
105
- for acc in text_ids:
106
- if acc not in seen:
107
- accessions.append(acc)
108
- seen.add(acc)
109
-
110
- return list(accessions), None
111
- # ✅ Add a new helper to backend: `filter_unprocessed_accessions()`
112
- def get_incomplete_accessions(file_path):
113
- df = pd.read_excel(file_path)
114
-
115
- incomplete_accessions = []
116
- for _, row in df.iterrows():
117
- sample_id = str(row.get("Sample ID", "")).strip()
118
-
119
- # Skip if no sample ID
120
- if not sample_id:
121
- continue
122
-
123
- # Drop the Sample ID and check if the rest is empty
124
- other_cols = row.drop(labels=["Sample ID"], errors="ignore")
125
- if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
126
- # Extract the accession number from the sample ID using regex
127
- match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
128
- if match:
129
- incomplete_accessions.append(match.group(0))
130
- print(len(incomplete_accessions))
131
- return incomplete_accessions
132
-
133
- # GOOGLE_SHEET_NAME = "known_samples"
134
- # USAGE_DRIVE_FILENAME = "user_usage_log.json"
135
-
136
- def summarize_results(accession):
137
- # try cache first
138
- cached = check_known_output(accession)
139
- if cached:
140
- print(f"✅ Using cached result for {accession}")
141
- return [[
142
- cached["Sample ID"] or "unknown",
143
- cached["Predicted Country"] or "unknown",
144
- cached["Country Explanation"] or "unknown",
145
- cached["Predicted Sample Type"] or "unknown",
146
- cached["Sample Type Explanation"] or "unknown",
147
- cached["Sources"] or "No Links",
148
- cached["Time cost"]
149
- ]]
150
- # only run when nothing in the cache
151
- try:
152
- print("try gemini pipeline: ",accession)
153
- outputs = pipeline_classify_sample_location_cached(accession)
154
- # outputs = {'KU131308': {'isolate':'BRU18',
155
- # 'country': {'brunei': ['ncbi',
156
- # 'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
157
- # 'sample_type': {'modern':
158
- # ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
159
- # 'query_cost': 9.754999999999999e-05,
160
- # 'time_cost': '24.776 seconds',
161
- # 'source': ['https://doi.org/10.1007/s00439-015-1620-z',
162
- # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
163
- # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
164
- except Exception as e:
165
- return []#, f"Error: {e}", f"Error: {e}", f"Error: {e}"
166
-
167
- if accession not in outputs:
168
- print("no accession in output ", accession)
169
- return []#, "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
170
-
171
- row_score = []
172
- rows = []
173
- save_rows = []
174
- for key in outputs:
175
- pred_country, pred_sample, country_explanation, sample_explanation = "unknown","unknown","unknown","unknown"
176
- for section, results in outputs[key].items():
177
- if section == "country" or section =="sample_type":
178
- pred_output = []#"\n".join(list(results.keys()))
179
- output_explanation = ""
180
- for result, content in results.items():
181
- if len(result) == 0: result = "unknown"
182
- if len(content) == 0: output_explanation = "unknown"
183
- else:
184
- output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
185
- pred_output.append(result)
186
- pred_output = "\n".join(pred_output)
187
- if section == "country":
188
- pred_country, country_explanation = pred_output, output_explanation
189
- elif section == "sample_type":
190
- pred_sample, sample_explanation = pred_output, output_explanation
191
- if outputs[key]["isolate"].lower()!="unknown":
192
- label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
193
- else: label = key
194
- if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
195
- row = {
196
- "Sample ID": label or "unknown",
197
- "Predicted Country": pred_country or "unknown",
198
- "Country Explanation": country_explanation or "unknown",
199
- "Predicted Sample Type":pred_sample or "unknown",
200
- "Sample Type Explanation":sample_explanation or "unknown",
201
- "Sources": "\n".join(outputs[key]["source"]) or "No Links",
202
- "Time cost": outputs[key]["time_cost"]
203
- }
204
- #row_score.append(row)
205
- rows.append(list(row.values()))
206
-
207
- save_row = {
208
- "Sample ID": label or "unknown",
209
- "Predicted Country": pred_country or "unknown",
210
- "Country Explanation": country_explanation or "unknown",
211
- "Predicted Sample Type":pred_sample or "unknown",
212
- "Sample Type Explanation":sample_explanation or "unknown",
213
- "Sources": "\n".join(outputs[key]["source"]) or "No Links",
214
- "Query_cost": outputs[key]["query_cost"],
215
- "Time cost": outputs[key]["time_cost"]
216
- }
217
- #row_score.append(row)
218
- save_rows.append(list(save_row.values()))
219
-
220
- # #location_counts, (final_location, count) = compute_final_suggested_location(row_score)
221
- # summary_lines = [f"### 🧭 Location Summary:\n"]
222
- # summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
223
- # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
224
- # summary = "\n".join(summary_lines)
225
-
226
- # save the new running sample to known excel file
227
- # try:
228
- # df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost","Time cost"])
229
- # if os.path.exists(KNOWN_OUTPUT_PATH):
230
- # df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
231
- # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
232
- # else:
233
- # df_combined = df_new
234
- # df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
235
- # except Exception as e:
236
- # print(f"⚠️ Failed to save known output: {e}")
237
- # try:
238
- # df_new = pd.DataFrame(save_rows, columns=[
239
- # "Sample ID", "Predicted Country", "Country Explanation",
240
- # "Predicted Sample Type", "Sample Type Explanation",
241
- # "Sources", "Query_cost", "Time cost"
242
- # ])
243
-
244
- # # ✅ Google Sheets API setup
245
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
246
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
247
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
248
- # client = gspread.authorize(creds)
249
-
250
- # # ✅ Open the known_samples sheet
251
- # spreadsheet = client.open("known_samples") # Replace with your sheet name
252
- # sheet = spreadsheet.sheet1
253
-
254
- # # ✅ Read old data
255
- # existing_data = sheet.get_all_values()
256
- # if existing_data:
257
- # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
258
- # else:
259
- # df_old = pd.DataFrame(columns=df_new.columns)
260
-
261
- # # ✅ Combine and remove duplicates
262
- # df_combined = pd.concat([df_old, df_new], ignore_index=True).drop_duplicates(subset="Sample ID")
263
-
264
- # # ✅ Clear and write back
265
- # sheet.clear()
266
- # sheet.update([df_combined.columns.values.tolist()] + df_combined.values.tolist())
267
-
268
- # except Exception as e:
269
- # print(f"⚠️ Failed to save known output to Google Sheets: {e}")
270
- try:
271
- # Prepare as DataFrame
272
- df_new = pd.DataFrame(save_rows, columns=[
273
- "Sample ID", "Predicted Country", "Country Explanation",
274
- "Predicted Sample Type", "Sample Type Explanation",
275
- "Sources", "Query_cost", "Time cost"
276
- ])
277
-
278
- # ✅ Setup Google Sheets
279
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
280
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
281
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
282
- client = gspread.authorize(creds)
283
- spreadsheet = client.open("known_samples")
284
- sheet = spreadsheet.sheet1
285
-
286
- # ✅ Read existing data
287
- existing_data = sheet.get_all_values()
288
- if existing_data:
289
- df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
290
- else:
291
- df_old = pd.DataFrame(columns=[
292
- "Sample ID", "Actual_country", "Actual_sample_type", "Country Explanation",
293
- "Match_country", "Match_sample_type", "Predicted Country", "Predicted Sample Type",
294
- "Query_cost", "Sample Type Explanation", "Sources", "Time cost"
295
- ])
296
-
297
- # ✅ Index by Sample ID
298
- df_old.set_index("Sample ID", inplace=True)
299
- df_new.set_index("Sample ID", inplace=True)
300
-
301
- # ✅ Update only matching fields
302
- update_columns = [
303
- "Predicted Country", "Predicted Sample Type", "Country Explanation",
304
- "Sample Type Explanation", "Sources", "Query_cost", "Time cost"
305
- ]
306
- for idx, row in df_new.iterrows():
307
- if idx not in df_old.index:
308
- df_old.loc[idx] = "" # new row, fill empty first
309
- for col in update_columns:
310
- if pd.notna(row[col]) and row[col] != "":
311
- df_old.at[idx, col] = row[col]
312
-
313
- # ✅ Reset and write back
314
- df_old.reset_index(inplace=True)
315
- sheet.clear()
316
- sheet.update([df_old.columns.values.tolist()] + df_old.values.tolist())
317
- print("✅ Match results saved to known_samples.")
318
-
319
- except Exception as e:
320
- print(f"❌ Failed to update known_samples: {e}")
321
-
322
-
323
- return rows#, summary, labelAncient_Modern, explain_label
324
-
325
- # save the batch input in excel file
326
- # def save_to_excel(all_rows, summary_text, flag_text, filename):
327
- # with pd.ExcelWriter(filename) as writer:
328
- # # Save table
329
- # df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
330
- # df.to_excel(writer, sheet_name="Detailed Results", index=False)
331
- # try:
332
- # df_old = pd.read_excel(filename)
333
- # except:
334
- # df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
335
- # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
336
- # # if os.path.exists(filename):
337
- # # df_old = pd.read_excel(filename)
338
- # # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
339
- # # else:
340
- # # df_combined = df_new
341
- # df_combined.to_excel(filename, index=False)
342
- # # # Save summary
343
- # # summary_df = pd.DataFrame({"Summary": [summary_text]})
344
- # # summary_df.to_excel(writer, sheet_name="Summary", index=False)
345
-
346
- # # # Save flag
347
- # # flag_df = pd.DataFrame({"Flag": [flag_text]})
348
- # # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
349
- # def save_to_excel(all_rows, summary_text, flag_text, filename):
350
- # df_new = pd.DataFrame(all_rows, columns=[
351
- # "Sample ID", "Predicted Country", "Country Explanation",
352
- # "Predicted Sample Type", "Sample Type Explanation",
353
- # "Sources", "Time cost"
354
- # ])
355
-
356
- # try:
357
- # if os.path.exists(filename):
358
- # df_old = pd.read_excel(filename)
359
- # else:
360
- # df_old = pd.DataFrame(columns=df_new.columns)
361
- # except Exception as e:
362
- # print(f"⚠️ Warning reading old Excel file: {e}")
363
- # df_old = pd.DataFrame(columns=df_new.columns)
364
-
365
- # #df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
366
- # df_old.set_index("Sample ID", inplace=True)
367
- # df_new.set_index("Sample ID", inplace=True)
368
-
369
- # df_old.update(df_new) # <-- update matching rows in df_old with df_new content
370
-
371
- # df_combined = df_old.reset_index()
372
-
373
- # try:
374
- # df_combined.to_excel(filename, index=False)
375
- # except Exception as e:
376
- # print(f"❌ Failed to write Excel file {filename}: {e}")
377
- def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
378
- df_new = pd.DataFrame(all_rows, columns=[
379
- "Sample ID", "Predicted Country", "Country Explanation",
380
- "Predicted Sample Type", "Sample Type Explanation",
381
- "Sources", "Time cost"
382
- ])
383
-
384
- if is_resume and os.path.exists(filename):
385
- try:
386
- df_old = pd.read_excel(filename)
387
- except Exception as e:
388
- print(f"⚠️ Warning reading old Excel file: {e}")
389
- df_old = pd.DataFrame(columns=df_new.columns)
390
-
391
- # Set index and update existing rows
392
- df_old.set_index("Sample ID", inplace=True)
393
- df_new.set_index("Sample ID", inplace=True)
394
- df_old.update(df_new)
395
-
396
- df_combined = df_old.reset_index()
397
- else:
398
- # If not resuming or file doesn't exist, just use new rows
399
- df_combined = df_new
400
-
401
- try:
402
- df_combined.to_excel(filename, index=False)
403
- except Exception as e:
404
- print(f"❌ Failed to write Excel file {filename}: {e}")
405
-
406
-
407
- # save the batch input in JSON file
408
- def save_to_json(all_rows, summary_text, flag_text, filename):
409
- output_dict = {
410
- "Detailed_Results": all_rows#, # <-- make sure this is a plain list, not a DataFrame
411
- # "Summary_Text": summary_text,
412
- # "Ancient_Modern_Flag": flag_text
413
- }
414
-
415
- # If all_rows is a DataFrame, convert it
416
- if isinstance(all_rows, pd.DataFrame):
417
- output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
418
-
419
- with open(filename, "w") as external_file:
420
- json.dump(output_dict, external_file, indent=2)
421
-
422
- # save the batch input in Text file
423
- def save_to_txt(all_rows, summary_text, flag_text, filename):
424
- if isinstance(all_rows, pd.DataFrame):
425
- detailed_results = all_rows.to_dict(orient="records")
426
- output = ""
427
- #output += ",".join(list(detailed_results[0].keys())) + "\n\n"
428
- output += ",".join([str(k) for k in detailed_results[0].keys()]) + "\n\n"
429
- for r in detailed_results:
430
- output += ",".join([str(v) for v in r.values()]) + "\n\n"
431
- with open(filename, "w") as f:
432
- f.write("=== Detailed Results ===\n")
433
- f.write(output + "\n")
434
-
435
- # f.write("\n=== Summary ===\n")
436
- # f.write(summary_text + "\n")
437
-
438
- # f.write("\n=== Ancient/Modern Flag ===\n")
439
- # f.write(flag_text + "\n")
440
-
441
- def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
442
- tmp_dir = tempfile.mkdtemp()
443
-
444
- #html_table = all_rows.value # assuming this is stored somewhere
445
-
446
- # Parse back to DataFrame
447
- #all_rows = pd.read_html(all_rows)[0] # [0] because read_html returns a list
448
- all_rows = pd.read_html(StringIO(all_rows))[0]
449
- print(all_rows)
450
-
451
- if output_type == "Excel":
452
- file_path = f"{tmp_dir}/batch_output.xlsx"
453
- save_to_excel(all_rows, summary_text, flag_text, file_path)
454
- elif output_type == "JSON":
455
- file_path = f"{tmp_dir}/batch_output.json"
456
- save_to_json(all_rows, summary_text, flag_text, file_path)
457
- print("Done with JSON")
458
- elif output_type == "TXT":
459
- file_path = f"{tmp_dir}/batch_output.txt"
460
- save_to_txt(all_rows, summary_text, flag_text, file_path)
461
- else:
462
- return gr.update(visible=False) # invalid option
463
-
464
- return gr.update(value=file_path, visible=True)
465
- # save cost by checking the known outputs
466
-
467
- # def check_known_output(accession):
468
- # if not os.path.exists(KNOWN_OUTPUT_PATH):
469
- # return None
470
-
471
- # try:
472
- # df = pd.read_excel(KNOWN_OUTPUT_PATH)
473
- # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
474
- # if match:
475
- # accession = match.group(0)
476
-
477
- # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
478
- # if not matched.empty:
479
- # return matched.iloc[0].to_dict() # Return the cached row
480
- # except Exception as e:
481
- # print(f"⚠️ Failed to load known samples: {e}")
482
- # return None
483
-
484
- # def check_known_output(accession):
485
- # try:
486
- # # ✅ Load credentials from Hugging Face secret
487
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
488
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
489
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
490
- # client = gspread.authorize(creds)
491
-
492
- # # ✅ Open the known_samples sheet
493
- # spreadsheet = client.open("known_samples") # Replace with your sheet name
494
- # sheet = spreadsheet.sheet1
495
-
496
- # # ✅ Read all rows
497
- # data = sheet.get_all_values()
498
- # if not data:
499
- # return None
500
-
501
- # df = pd.DataFrame(data[1:], columns=data[0]) # Skip header row
502
-
503
- # # ✅ Normalize accession pattern
504
- # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
505
- # if match:
506
- # accession = match.group(0)
507
-
508
- # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
509
- # if not matched.empty:
510
- # return matched.iloc[0].to_dict()
511
-
512
- # except Exception as e:
513
- # print(f"⚠️ Failed to load known samples from Google Sheets: {e}")
514
- # return None
515
- def check_known_output(accession):
516
- try:
517
- # ✅ Load credentials from Hugging Face secret
518
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
519
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
520
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
521
- client = gspread.authorize(creds)
522
-
523
- spreadsheet = client.open("known_samples")
524
- sheet = spreadsheet.sheet1
525
-
526
- data = sheet.get_all_values()
527
- if not data:
528
- print("⚠️ Google Sheet 'known_samples' is empty.")
529
- return None
530
-
531
- df = pd.DataFrame(data[1:], columns=data[0])
532
- if "Sample ID" not in df.columns:
533
- print("❌ Column 'Sample ID' not found in Google Sheet.")
534
- return None
535
-
536
- match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
537
- if match:
538
- accession = match.group(0)
539
-
540
- matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
541
- if not matched.empty:
542
- #return matched.iloc[0].to_dict()
543
- row = matched.iloc[0]
544
- country = row.get("Predicted Country", "").strip().lower()
545
- sample_type = row.get("Predicted Sample Type", "").strip().lower()
546
-
547
- if country and country != "unknown" and sample_type and sample_type != "unknown":
548
- return row.to_dict()
549
- else:
550
- print(f"⚠️ Accession {accession} found but country/sample_type is unknown or empty.")
551
- return None
552
- else:
553
- print(f"🔍 Accession {accession} not found in known_samples.")
554
- return None
555
-
556
- except Exception as e:
557
- import traceback
558
- print("❌ Exception occurred during check_known_output:")
559
- traceback.print_exc()
560
- return None
561
-
562
-
563
- def hash_user_id(user_input):
564
- return hashlib.sha256(user_input.encode()).hexdigest()
565
-
566
- # ✅ Load and save usage count
567
-
568
- # def load_user_usage():
569
- # if not os.path.exists(USER_USAGE_TRACK_FILE):
570
- # return {}
571
-
572
- # try:
573
- # with open(USER_USAGE_TRACK_FILE, "r") as f:
574
- # content = f.read().strip()
575
- # if not content:
576
- # return {} # file is empty
577
- # return json.loads(content)
578
- # except (json.JSONDecodeError, ValueError):
579
- # print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
580
- # return {} # fallback to empty dict
581
- # def load_user_usage():
582
- # try:
583
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
584
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
585
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
586
- # client = gspread.authorize(creds)
587
-
588
- # sheet = client.open("user_usage_log").sheet1
589
- # data = sheet.get_all_records() # Assumes columns: email, usage_count
590
-
591
- # usage = {}
592
- # for row in data:
593
- # email = row.get("email", "").strip().lower()
594
- # count = int(row.get("usage_count", 0))
595
- # if email:
596
- # usage[email] = count
597
- # return usage
598
- # except Exception as e:
599
- # print(f"⚠️ Failed to load user usage from Google Sheets: {e}")
600
- # return {}
601
- # def load_user_usage():
602
- # try:
603
- # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
604
- # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
605
-
606
- # found = pipeline.find_drive_file("user_usage_log.json", parent_id=iterate3_id)
607
- # if not found:
608
- # return {} # not found, start fresh
609
-
610
- # #file_id = found[0]["id"]
611
- # file_id = found
612
- # content = pipeline.download_drive_file_content(file_id)
613
- # return json.loads(content.strip()) if content.strip() else {}
614
-
615
- # except Exception as e:
616
- # print(f"⚠️ Failed to load user_usage_log.json from Google Drive: {e}")
617
- # return {}
618
- def load_user_usage():
619
- try:
620
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
621
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
622
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
623
- client = gspread.authorize(creds)
624
-
625
- sheet = client.open("user_usage_log").sheet1
626
- data = sheet.get_all_values()
627
- print("data: ", data)
628
- print("🧪 Raw header row from sheet:", data[0])
629
- print("🧪 Character codes in each header:")
630
- for h in data[0]:
631
- print([ord(c) for c in h])
632
-
633
- if not data or len(data) < 2:
634
- print("⚠️ Sheet is empty or missing rows.")
635
- return {}
636
-
637
- headers = [h.strip().lower() for h in data[0]]
638
- if "email" not in headers or "usage_count" not in headers:
639
- print(" Header format incorrect. Must have 'email' and 'usage_count'.")
640
- return {}
641
-
642
- df = pd.DataFrame(data[1:], columns=headers)
643
-
644
- usage = {}
645
- for _, row in df.iterrows():
646
- email = row.get("email", "").strip().lower()
647
- try:
648
- #count = int(row.get("usage_count", 0))
649
- try:
650
- count = int(float(row.get("usage_count", 0)))
651
- except Exception:
652
- print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
653
- count = 0
654
-
655
- if email:
656
- usage[email] = count
657
- except ValueError:
658
- print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
659
- return usage
660
-
661
- except Exception as e:
662
- print(f" Error in load_user_usage: {e}")
663
- return {}
664
-
665
-
666
-
667
- # def save_user_usage(usage):
668
- # with open(USER_USAGE_TRACK_FILE, "w") as f:
669
- # json.dump(usage, f, indent=2)
670
-
671
- # def save_user_usage(usage_dict):
672
- # try:
673
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
674
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
675
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
676
- # client = gspread.authorize(creds)
677
-
678
- # sheet = client.open("user_usage_log").sheet1
679
- # sheet.clear() # clear old contents first
680
-
681
- # # Write header + rows
682
- # rows = [["email", "usage_count"]] + [[email, count] for email, count in usage_dict.items()]
683
- # sheet.update(rows)
684
- # except Exception as e:
685
- # print(f"❌ Failed to save user usage to Google Sheets: {e}")
686
- # def save_user_usage(usage_dict):
687
- # try:
688
- # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
689
- # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
690
-
691
- # import tempfile
692
- # tmp_path = os.path.join(tempfile.gettempdir(), "user_usage_log.json")
693
- # print("💾 Saving this usage dict:", usage_dict)
694
- # with open(tmp_path, "w") as f:
695
- # json.dump(usage_dict, f, indent=2)
696
-
697
- # pipeline.upload_file_to_drive(tmp_path, "user_usage_log.json", iterate3_id)
698
-
699
- # except Exception as e:
700
- # print(f"❌ Failed to save user_usage_log.json to Google Drive: {e}")
701
- # def save_user_usage(usage_dict):
702
- # try:
703
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
704
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
705
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
706
- # client = gspread.authorize(creds)
707
-
708
- # spreadsheet = client.open("user_usage_log")
709
- # sheet = spreadsheet.sheet1
710
-
711
- # # Step 1: Convert new usage to DataFrame
712
- # df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
713
- # df_new["email"] = df_new["email"].str.strip().str.lower()
714
-
715
- # # Step 2: Load existing data
716
- # existing_data = sheet.get_all_values()
717
- # print("🧪 Sheet existing_data:", existing_data)
718
-
719
- # # Try to load old data
720
- # if existing_data and len(existing_data[0]) >= 1:
721
- # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
722
-
723
- # # Fix missing columns
724
- # if "email" not in df_old.columns:
725
- # df_old["email"] = ""
726
- # if "usage_count" not in df_old.columns:
727
- # df_old["usage_count"] = 0
728
-
729
- # df_old["email"] = df_old["email"].str.strip().str.lower()
730
- # df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
731
- # else:
732
- # df_old = pd.DataFrame(columns=["email", "usage_count"])
733
-
734
- # # Step 3: Merge
735
- # df_combined = pd.concat([df_old, df_new], ignore_index=True)
736
- # df_combined = df_combined.groupby("email", as_index=False).sum()
737
-
738
- # # Step 4: Write back
739
- # sheet.clear()
740
- # sheet.update([df_combined.columns.tolist()] + df_combined.astype(str).values.tolist())
741
- # print("✅ Saved user usage to user_usage_log sheet.")
742
-
743
- # except Exception as e:
744
- # print(f"❌ Failed to save user usage to Google Sheets: {e}")
745
- def save_user_usage(usage_dict):
746
- try:
747
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
748
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
749
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
750
- client = gspread.authorize(creds)
751
-
752
- spreadsheet = client.open("user_usage_log")
753
- sheet = spreadsheet.sheet1
754
-
755
- # Build new df
756
- df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
757
- df_new["email"] = df_new["email"].str.strip().str.lower()
758
- df_new["usage_count"] = pd.to_numeric(df_new["usage_count"], errors="coerce").fillna(0).astype(int)
759
-
760
- # Read existing data
761
- existing_data = sheet.get_all_values()
762
- if existing_data and len(existing_data[0]) >= 2:
763
- df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
764
- df_old["email"] = df_old["email"].str.strip().str.lower()
765
- df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
766
- else:
767
- df_old = pd.DataFrame(columns=["email", "usage_count"])
768
-
769
- # ✅ Overwrite specific emails only
770
- df_old = df_old.set_index("email")
771
- for email, count in usage_dict.items():
772
- email = email.strip().lower()
773
- df_old.loc[email, "usage_count"] = count
774
- df_old = df_old.reset_index()
775
-
776
- # Save
777
- sheet.clear()
778
- sheet.update([df_old.columns.tolist()] + df_old.astype(str).values.tolist())
779
- print("Saved user usage to user_usage_log sheet.")
780
-
781
- except Exception as e:
782
- print(f"❌ Failed to save user usage to Google Sheets: {e}")
783
-
784
-
785
-
786
-
787
- # def increment_usage(user_id, num_samples=1):
788
- # usage = load_user_usage()
789
- # if user_id not in usage:
790
- # usage[user_id] = 0
791
- # usage[user_id] += num_samples
792
- # save_user_usage(usage)
793
- # return usage[user_id]
794
- # def increment_usage(email: str, count: int):
795
- # usage = load_user_usage()
796
- # email_key = email.strip().lower()
797
- # usage[email_key] = usage.get(email_key, 0) + count
798
- # save_user_usage(usage)
799
- # return usage[email_key]
800
- def increment_usage(email: str, count: int = 1):
801
- usage = load_user_usage()
802
- email_key = email.strip().lower()
803
- #usage[email_key] = usage.get(email_key, 0) + count
804
- current = usage.get(email_key, 0)
805
- new_value = current + count
806
- usage[email_key] = max(current, new_value) # ✅ Prevent overwrite with lower
807
- print(f"🧪 increment_usage saving: {email_key=} {current=} + {count=} => {usage[email_key]=}")
808
- save_user_usage(usage)
809
- return usage[email_key]
810
-
811
-
812
- # run the batch
813
- def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
814
- stop_flag=None, output_file_path=None,
815
- limited_acc=50, yield_callback=None):
816
- if user_email:
817
- limited_acc += 10
818
- accessions, error = extract_accessions_from_input(file, raw_text)
819
- if error:
820
- #return [], "", "", f"Error: {error}"
821
- return [], f"Error: {error}", 0, "", ""
822
- if resume_file:
823
- accessions = get_incomplete_accessions(resume_file)
824
- tmp_dir = tempfile.mkdtemp()
825
- if not output_file_path:
826
- if resume_file:
827
- output_file_path = os.path.join(tmp_dir, resume_file)
828
- else:
829
- output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
830
-
831
- all_rows = []
832
- # all_summaries = []
833
- # all_flags = []
834
- progress_lines = []
835
- warning = ""
836
- if len(accessions) > limited_acc:
837
- accessions = accessions[:limited_acc]
838
- warning = f"Your number of accessions is more than the {limited_acc}, only handle first {limited_acc} accessions"
839
- for i, acc in enumerate(accessions):
840
- if stop_flag and stop_flag.value:
841
- line = f"🛑 Stopped at {acc} ({i+1}/{len(accessions)})"
842
- progress_lines.append(line)
843
- if yield_callback:
844
- yield_callback(line)
845
- print("🛑 User requested stop.")
846
- break
847
- print(f"[{i+1}/{len(accessions)}] Processing {acc}")
848
- try:
849
- # rows, summary, label, explain = summarize_results(acc)
850
- rows = summarize_results(acc)
851
- all_rows.extend(rows)
852
- # all_summaries.append(f"**{acc}**\n{summary}")
853
- # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
854
- #save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
855
- save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
856
- line = f"✅ Processed {acc} ({i+1}/{len(accessions)})"
857
- progress_lines.append(line)
858
- if yield_callback:
859
- yield_callback(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
860
- except Exception as e:
861
- print(f"❌ Failed to process {acc}: {e}")
862
- continue
863
- #all_summaries.append(f"**{acc}**: Failed - {e}")
864
- #progress_lines.append(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
865
- limited_acc -= 1
866
- """for row in all_rows:
867
- source_column = row[2] # Assuming the "Source" is in the 3rd column (index 2)
868
-
869
- if source_column.startswith("http"): # Check if the source is a URL
870
- # Wrap it with HTML anchor tags to make it clickable
871
- row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
872
- if not warning:
873
- warning = f"You only have {limited_acc} left"
874
- if user_email.strip():
875
- user_hash = hash_user_id(user_email)
876
- total_queries = increment_usage(user_hash, len(all_rows))
877
- else:
878
- total_queries = 0
879
- yield_callback(" Finished!")
880
-
881
- # summary_text = "\n\n---\n\n".join(all_summaries)
882
- # flag_text = "\n\n---\n\n".join(all_flags)
883
- #return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
884
- #return all_rows, gr.update(visible=True), gr.update(visible=False)
 
885
  return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning
 
1
+ import gradio as gr
2
+ from collections import Counter
3
+ import csv
4
+ import os
5
+ from functools import lru_cache
6
+ #import app
7
+ from mtdna_classifier import classify_sample_location
8
+ import data_preprocess, model, pipeline
9
+ import subprocess
10
+ import json
11
+ import pandas as pd
12
+ import io
13
+ import re
14
+ import tempfile
15
+ import gspread
16
+ from oauth2client.service_account import ServiceAccountCredentials
17
+ from io import StringIO
18
+ import hashlib
19
+ import threading
20
+
21
+ # @lru_cache(maxsize=3600)
22
+ # def classify_sample_location_cached(accession):
23
+ # return classify_sample_location(accession)
24
+
25
+ @lru_cache(maxsize=3600)
26
+ def pipeline_classify_sample_location_cached(accession):
27
+ print("inside pipeline_classify_sample_location_cached, and [accession] is ", [accession])
28
+ return pipeline.pipeline_with_gemini([accession])
29
+
30
+ # Count and suggest final location
31
+ # def compute_final_suggested_location(rows):
32
+ # candidates = [
33
+ # row.get("Predicted Location", "").strip()
34
+ # for row in rows
35
+ # if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
36
+ # ] + [
37
+ # row.get("Inferred Region", "").strip()
38
+ # for row in rows
39
+ # if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
40
+ # ]
41
+
42
+ # if not candidates:
43
+ # return Counter(), ("Unknown", 0)
44
+ # # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
45
+ # tokens = []
46
+ # for item in candidates:
47
+ # # Split by comma, whitespace, and newlines
48
+ # parts = re.split(r'[\s,]+', item)
49
+ # tokens.extend(parts)
50
+
51
+ # # Step 2: Clean and normalize tokens
52
+ # tokens = [word.strip() for word in tokens if word.strip().isalpha()] # Keep only alphabetic tokens
53
+
54
+ # # Step 3: Count
55
+ # counts = Counter(tokens)
56
+
57
+ # # Step 4: Get most common
58
+ # top_location, count = counts.most_common(1)[0]
59
+ # return counts, (top_location, count)
60
+
61
+ # Store feedback (with required fields)
62
+
63
+ def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
64
+ if not answer1.strip() or not answer2.strip():
65
+ return "⚠️ Please answer both questions before submitting."
66
+
67
+ try:
68
+ # Step: Load credentials from Hugging Face secret
69
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
70
+ scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
71
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
72
+
73
+ # Connect to Google Sheet
74
+ client = gspread.authorize(creds)
75
+ sheet = client.open("feedback_mtdna").sheet1 # make sure sheet name matches
76
+
77
+ # Append feedback
78
+ sheet.append_row([accession, answer1, answer2, contact])
79
+ return "✅ Feedback submitted. Thank you!"
80
+
81
+ except Exception as e:
82
+ return f"❌ Error submitting feedback: {e}"
83
+
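+ # NOTE: this relies on a service-account JSON in the GCP_CREDS_JSON secret and a Google Sheet named
+ # "feedback_mtdna" shared with that service account. Hypothetical call (arguments are placeholders):
+ #   msg = store_feedback_to_google_sheets("KU131308", "yes", "location looks right", contact="user@example.com")
+ #   print(msg)  # "✅ Feedback submitted. Thank you!" on success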
84
+ # helper function to extract accessions
85
+ def extract_accessions_from_input(file=None, raw_text=""):
86
+ print(f"RAW TEXT RECEIVED: {raw_text}")
87
+ accessions = []
88
+ seen = set()
89
+ if file:
90
+ try:
91
+ if file.name.endswith(".csv"):
92
+ df = pd.read_csv(file)
93
+ elif file.name.endswith(".xlsx"):
94
+ df = pd.read_excel(file)
95
+ else:
96
+ return [], "Unsupported file format. Please upload CSV or Excel."
97
+ for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
98
+ if acc not in seen:
99
+ accessions.append(acc)
100
+ seen.add(acc)
101
+ except Exception as e:
102
+ return [], f"Failed to read file: {e}"
103
+
104
+ if raw_text:
105
+ text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
106
+ for acc in text_ids:
107
+ if acc not in seen:
108
+ accessions.append(acc)
109
+ seen.add(acc)
110
+
111
+ return list(accessions), None
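+ # NOTE: accessions are de-duplicated while preserving order, and file input is read before raw text.
+ # Hypothetical call (IDs are placeholders):
+ #   ids, err = extract_accessions_from_input(raw_text="KU131308, MK291745\nKU131308")
+ #   # ids == ["KU131308", "MK291745"], err is None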
112
+ # ✅ Helper for resuming a batch: get_incomplete_accessions() returns accessions whose output rows are still empty
113
+ def get_incomplete_accessions(file_path):
114
+ df = pd.read_excel(file_path)
115
+
116
+ incomplete_accessions = []
117
+ for _, row in df.iterrows():
118
+ sample_id = str(row.get("Sample ID", "")).strip()
119
+
120
+ # Skip if no sample ID
121
+ if not sample_id:
122
+ continue
123
+
124
+ # Drop the Sample ID and check if the rest is empty
125
+ other_cols = row.drop(labels=["Sample ID"], errors="ignore")
126
+ if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
127
+ # Extract the accession number from the sample ID using regex
128
+ match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
129
+ if match:
130
+ incomplete_accessions.append(match.group(0))
131
+ print(len(incomplete_accessions))
132
+ return incomplete_accessions
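+ # NOTE: a row counts as incomplete when every column except "Sample ID" is empty, and the accession
+ # is recovered from the Sample ID label (e.g. "KU131308(Isolate: BRU18)") via the regex
+ # \b[A-Z]{2,4}\d{4,}; rows whose Sample ID has no such token are skipped.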
133
+
134
+ # GOOGLE_SHEET_NAME = "known_samples"
135
+ # USAGE_DRIVE_FILENAME = "user_usage_log.json"
136
+
137
+ def summarize_results(accession):
138
+ # try cache first
139
+ cached = check_known_output(accession)
140
+ if cached:
141
+ print(f"✅ Using cached result for {accession}")
142
+ return [[
143
+ cached["Sample ID"] or "unknown",
144
+ cached["Predicted Country"] or "unknown",
145
+ cached["Country Explanation"] or "unknown",
146
+ cached["Predicted Sample Type"] or "unknown",
147
+ cached["Sample Type Explanation"] or "unknown",
148
+ cached["Sources"] or "No Links",
149
+ cached["Time cost"]
150
+ ]]
151
+ # only run when nothing in the cache
152
+ try:
153
+ print("try gemini pipeline: ",accession)
154
+ outputs = pipeline_classify_sample_location_cached(accession)
155
+ # outputs = {'KU131308': {'isolate':'BRU18',
156
+ # 'country': {'brunei': ['ncbi',
157
+ # 'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
158
+ # 'sample_type': {'modern':
159
+ # ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
160
+ # 'query_cost': 9.754999999999999e-05,
161
+ # 'time_cost': '24.776 seconds',
162
+ # 'source': ['https://doi.org/10.1007/s00439-015-1620-z',
163
+ # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
164
+ # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
165
+ except Exception as e:
166
+ return []#, f"Error: {e}", f"Error: {e}", f"Error: {e}"
167
+
168
+ if accession not in outputs:
169
+ print("no accession in output ", accession)
170
+ return []#, "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
171
+
172
+ row_score = []
173
+ rows = []
174
+ save_rows = []
175
+ for key in outputs:
176
+ pred_country, pred_sample, country_explanation, sample_explanation = "unknown","unknown","unknown","unknown"
177
+ for section, results in outputs[key].items():
178
+ if section == "country" or section =="sample_type":
179
+ pred_output = []#"\n".join(list(results.keys()))
180
+ output_explanation = ""
181
+ for result, content in results.items():
182
+ if len(result) == 0: result = "unknown"
183
+ if len(content) == 0: output_explanation = "unknown"
184
+ else:
185
+ output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
186
+ pred_output.append(result)
187
+ pred_output = "\n".join(pred_output)
188
+ if section == "country":
189
+ pred_country, country_explanation = pred_output, output_explanation
190
+ elif section == "sample_type":
191
+ pred_sample, sample_explanation = pred_output, output_explanation
192
+ if outputs[key]["isolate"].lower()!="unknown":
193
+ label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
194
+ else: label = key
195
+ if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
196
+ row = {
197
+ "Sample ID": label or "unknown",
198
+ "Predicted Country": pred_country or "unknown",
199
+ "Country Explanation": country_explanation or "unknown",
200
+ "Predicted Sample Type":pred_sample or "unknown",
201
+ "Sample Type Explanation":sample_explanation or "unknown",
202
+ "Sources": "\n".join(outputs[key]["source"]) or "No Links",
203
+ "Time cost": outputs[key]["time_cost"]
204
+ }
205
+ #row_score.append(row)
206
+ rows.append(list(row.values()))
207
+
208
+ save_row = {
209
+ "Sample ID": label or "unknown",
210
+ "Predicted Country": pred_country or "unknown",
211
+ "Country Explanation": country_explanation or "unknown",
212
+ "Predicted Sample Type":pred_sample or "unknown",
213
+ "Sample Type Explanation":sample_explanation or "unknown",
214
+ "Sources": "\n".join(outputs[key]["source"]) or "No Links",
215
+ "Query_cost": outputs[key]["query_cost"],
216
+ "Time cost": outputs[key]["time_cost"]
217
+ }
218
+ #row_score.append(row)
219
+ save_rows.append(list(save_row.values()))
220
+
221
+ # #location_counts, (final_location, count) = compute_final_suggested_location(row_score)
222
+ # summary_lines = [f"### 🧭 Location Summary:\n"]
223
+ # summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
224
+ # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
225
+ # summary = "\n".join(summary_lines)
226
+
227
+ # save the new running sample to known excel file
228
+ # try:
229
+ # df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost","Time cost"])
230
+ # if os.path.exists(KNOWN_OUTPUT_PATH):
231
+ # df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
232
+ # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
233
+ # else:
234
+ # df_combined = df_new
235
+ # df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
236
+ # except Exception as e:
237
+ # print(f"⚠️ Failed to save known output: {e}")
238
+ # try:
239
+ # df_new = pd.DataFrame(save_rows, columns=[
240
+ # "Sample ID", "Predicted Country", "Country Explanation",
241
+ # "Predicted Sample Type", "Sample Type Explanation",
242
+ # "Sources", "Query_cost", "Time cost"
243
+ # ])
244
+
245
+ # # Google Sheets API setup
246
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
247
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
248
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
249
+ # client = gspread.authorize(creds)
250
+
251
+ # # Open the known_samples sheet
252
+ # spreadsheet = client.open("known_samples") # Replace with your sheet name
253
+ # sheet = spreadsheet.sheet1
254
+
255
+ # # Read old data
256
+ # existing_data = sheet.get_all_values()
257
+ # if existing_data:
258
+ # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
259
+ # else:
260
+ # df_old = pd.DataFrame(columns=df_new.columns)
261
+
262
+ # # Combine and remove duplicates
263
+ # df_combined = pd.concat([df_old, df_new], ignore_index=True).drop_duplicates(subset="Sample ID")
264
+
265
+ # # ✅ Clear and write back
266
+ # sheet.clear()
267
+ # sheet.update([df_combined.columns.values.tolist()] + df_combined.values.tolist())
268
+
269
+ # except Exception as e:
270
+ # print(f"⚠️ Failed to save known output to Google Sheets: {e}")
271
+ try:
272
+ # Prepare as DataFrame
273
+ df_new = pd.DataFrame(save_rows, columns=[
274
+ "Sample ID", "Predicted Country", "Country Explanation",
275
+ "Predicted Sample Type", "Sample Type Explanation",
276
+ "Sources", "Query_cost", "Time cost"
277
+ ])
278
+
279
+ # Setup Google Sheets
280
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
281
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
282
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
283
+ client = gspread.authorize(creds)
284
+ spreadsheet = client.open("known_samples")
285
+ sheet = spreadsheet.sheet1
286
+
287
+ # Read existing data
288
+ existing_data = sheet.get_all_values()
289
+ if existing_data:
290
+ df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
291
+ else:
292
+ df_old = pd.DataFrame(columns=[
293
+ "Sample ID", "Actual_country", "Actual_sample_type", "Country Explanation",
294
+ "Match_country", "Match_sample_type", "Predicted Country", "Predicted Sample Type",
295
+ "Query_cost", "Sample Type Explanation", "Sources", "Time cost"
296
+ ])
297
+
298
+ # ✅ Index by Sample ID
299
+ df_old.set_index("Sample ID", inplace=True)
300
+ df_new.set_index("Sample ID", inplace=True)
301
+
302
+ # Update only matching fields
303
+ update_columns = [
304
+ "Predicted Country", "Predicted Sample Type", "Country Explanation",
305
+ "Sample Type Explanation", "Sources", "Query_cost", "Time cost"
306
+ ]
307
+ for idx, row in df_new.iterrows():
308
+ if idx not in df_old.index:
309
+ df_old.loc[idx] = "" # new row, fill empty first
310
+ for col in update_columns:
311
+ if pd.notna(row[col]) and row[col] != "":
312
+ df_old.at[idx, col] = row[col]
313
+
314
+ # ✅ Reset and write back
315
+ df_old.reset_index(inplace=True)
316
+ sheet.clear()
317
+ sheet.update([df_old.columns.values.tolist()] + df_old.values.tolist())
318
+ print("✅ Match results saved to known_samples.")
319
+
320
+ except Exception as e:
321
+ print(f"❌ Failed to update known_samples: {e}")
322
+
323
+
324
+ return rows#, summary, labelAncient_Modern, explain_label
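+ # NOTE: summarize_results returns a list of 7-element rows in the order
+ # [Sample ID, Predicted Country, Country Explanation, Predicted Sample Type,
+ #  Sample Type Explanation, Sources, Time cost]; Query_cost is only written to the
+ # known_samples sheet, not returned to the UI table.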
325
+
326
+ # save the batch input in excel file
327
+ # def save_to_excel(all_rows, summary_text, flag_text, filename):
328
+ # with pd.ExcelWriter(filename) as writer:
329
+ # # Save table
330
+ # df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
331
+ # df.to_excel(writer, sheet_name="Detailed Results", index=False)
332
+ # try:
333
+ # df_old = pd.read_excel(filename)
334
+ # except:
335
+ # df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
336
+ # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
337
+ # # if os.path.exists(filename):
338
+ # # df_old = pd.read_excel(filename)
339
+ # # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
340
+ # # else:
341
+ # # df_combined = df_new
342
+ # df_combined.to_excel(filename, index=False)
343
+ # # # Save summary
344
+ # # summary_df = pd.DataFrame({"Summary": [summary_text]})
345
+ # # summary_df.to_excel(writer, sheet_name="Summary", index=False)
346
+
347
+ # # # Save flag
348
+ # # flag_df = pd.DataFrame({"Flag": [flag_text]})
349
+ # # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
350
+ # def save_to_excel(all_rows, summary_text, flag_text, filename):
351
+ # df_new = pd.DataFrame(all_rows, columns=[
352
+ # "Sample ID", "Predicted Country", "Country Explanation",
353
+ # "Predicted Sample Type", "Sample Type Explanation",
354
+ # "Sources", "Time cost"
355
+ # ])
356
+
357
+ # try:
358
+ # if os.path.exists(filename):
359
+ # df_old = pd.read_excel(filename)
360
+ # else:
361
+ # df_old = pd.DataFrame(columns=df_new.columns)
362
+ # except Exception as e:
363
+ # print(f"⚠️ Warning reading old Excel file: {e}")
364
+ # df_old = pd.DataFrame(columns=df_new.columns)
365
+
366
+ # #df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
367
+ # df_old.set_index("Sample ID", inplace=True)
368
+ # df_new.set_index("Sample ID", inplace=True)
369
+
370
+ # df_old.update(df_new) # <-- update matching rows in df_old with df_new content
371
+
372
+ # df_combined = df_old.reset_index()
373
+
374
+ # try:
375
+ # df_combined.to_excel(filename, index=False)
376
+ # except Exception as e:
377
+ # print(f"❌ Failed to write Excel file {filename}: {e}")
378
+ def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
379
+ df_new = pd.DataFrame(all_rows, columns=[
380
+ "Sample ID", "Predicted Country", "Country Explanation",
381
+ "Predicted Sample Type", "Sample Type Explanation",
382
+ "Sources", "Time cost"
383
+ ])
384
+
385
+ if is_resume and os.path.exists(filename):
386
+ try:
387
+ df_old = pd.read_excel(filename)
388
+ except Exception as e:
389
+ print(f"⚠️ Warning reading old Excel file: {e}")
390
+ df_old = pd.DataFrame(columns=df_new.columns)
391
+
392
+ # Set index and update existing rows
393
+ df_old.set_index("Sample ID", inplace=True)
394
+ df_new.set_index("Sample ID", inplace=True)
395
+ df_old.update(df_new)
396
+
397
+ df_combined = df_old.reset_index()
398
+ else:
399
+ # If not resuming or file doesn't exist, just use new rows
400
+ df_combined = df_new
401
+
402
+ try:
403
+ df_combined.to_excel(filename, index=False)
404
+ except Exception as e:
405
+ print(f"❌ Failed to write Excel file {filename}: {e}")
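+ # NOTE: with is_resume=True an existing workbook is updated in place by Sample ID; otherwise the
+ # file is overwritten with just the rows accumulated in the current run. Hypothetical call:
+ #   save_to_excel(all_rows, "", "", "batch_output_live.xlsx", is_resume=False)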
406
+
407
+
408
+ # save the batch input in JSON file
409
+ def save_to_json(all_rows, summary_text, flag_text, filename):
410
+ output_dict = {
411
+ "Detailed_Results": all_rows#, # <-- make sure this is a plain list, not a DataFrame
412
+ # "Summary_Text": summary_text,
413
+ # "Ancient_Modern_Flag": flag_text
414
+ }
415
+
416
+ # If all_rows is a DataFrame, convert it
417
+ if isinstance(all_rows, pd.DataFrame):
418
+ output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
419
+
420
+ with open(filename, "w") as external_file:
421
+ json.dump(output_dict, external_file, indent=2)
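+ # NOTE: the JSON payload currently has a single key, "Detailed_Results", holding the table rows
+ # (converted to records if a DataFrame is passed); the summary/flag arguments are accepted for
+ # API compatibility but are not written.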
422
+
423
+ # save the batch input in Text file
424
+ def save_to_txt(all_rows, summary_text, flag_text, filename):
425
+ if isinstance(all_rows, pd.DataFrame):
426
+ detailed_results = all_rows.to_dict(orient="records")
+ else:
+ # assume all_rows is already a list of dict rows (avoids a NameError when a non-DataFrame is passed)
+ detailed_results = all_rows
427
+ output = ""
428
+ #output += ",".join(list(detailed_results[0].keys())) + "\n\n"
429
+ output += ",".join([str(k) for k in detailed_results[0].keys()]) + "\n\n"
430
+ for r in detailed_results:
431
+ output += ",".join([str(v) for v in r.values()]) + "\n\n"
432
+ with open(filename, "w") as f:
433
+ f.write("=== Detailed Results ===\n")
434
+ f.write(output + "\n")
435
+
436
+ # f.write("\n=== Summary ===\n")
437
+ # f.write(summary_text + "\n")
438
+
439
+ # f.write("\n=== Ancient/Modern Flag ===\n")
440
+ # f.write(flag_text + "\n")
441
+
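+ # NOTE: the TXT export is a loose comma-joined dump (header line, blank line, then one comma-joined
+ # line per record); values are not quoted, so it is not strict CSV.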
442
+ def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
443
+ tmp_dir = tempfile.mkdtemp()
444
+
445
+ #html_table = all_rows.value # assuming this is stored somewhere
446
+
447
+ # Parse back to DataFrame
448
+ #all_rows = pd.read_html(all_rows)[0] # [0] because read_html returns a list
449
+ all_rows = pd.read_html(StringIO(all_rows))[0]
450
+ print(all_rows)
451
+
452
+ if output_type == "Excel":
453
+ file_path = f"{tmp_dir}/batch_output.xlsx"
454
+ save_to_excel(all_rows, summary_text, flag_text, file_path)
455
+ elif output_type == "JSON":
456
+ file_path = f"{tmp_dir}/batch_output.json"
457
+ save_to_json(all_rows, summary_text, flag_text, file_path)
458
+ print("Done with JSON")
459
+ elif output_type == "TXT":
460
+ file_path = f"{tmp_dir}/batch_output.txt"
461
+ save_to_txt(all_rows, summary_text, flag_text, file_path)
462
+ else:
463
+ return gr.update(visible=False) # invalid option
464
+
465
+ return gr.update(value=file_path, visible=True)
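+ # NOTE: the returned gr.update(...) is intended for a gr.File component, so the download link only
+ # becomes visible after the file is written; an invalid output_type hides the component instead.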
466
+ # save cost by checking the known outputs
467
+
468
+ # def check_known_output(accession):
469
+ # if not os.path.exists(KNOWN_OUTPUT_PATH):
470
+ # return None
471
+
472
+ # try:
473
+ # df = pd.read_excel(KNOWN_OUTPUT_PATH)
474
+ # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
475
+ # if match:
476
+ # accession = match.group(0)
477
+
478
+ # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
479
+ # if not matched.empty:
480
+ # return matched.iloc[0].to_dict() # Return the cached row
481
+ # except Exception as e:
482
+ # print(f"⚠️ Failed to load known samples: {e}")
483
+ # return None
484
+
485
+ # def check_known_output(accession):
486
+ # try:
487
+ # # Load credentials from Hugging Face secret
488
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
489
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
490
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
491
+ # client = gspread.authorize(creds)
492
+
493
+ # # Open the known_samples sheet
494
+ # spreadsheet = client.open("known_samples") # Replace with your sheet name
495
+ # sheet = spreadsheet.sheet1
496
+
497
+ # # Read all rows
498
+ # data = sheet.get_all_values()
499
+ # if not data:
500
+ # return None
501
+
502
+ # df = pd.DataFrame(data[1:], columns=data[0]) # Skip header row
503
+
504
+ # # Normalize accession pattern
505
+ # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
506
+ # if match:
507
+ # accession = match.group(0)
508
+
509
+ # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
510
+ # if not matched.empty:
511
+ # return matched.iloc[0].to_dict()
512
+
513
+ # except Exception as e:
514
+ # print(f"⚠️ Failed to load known samples from Google Sheets: {e}")
515
+ # return None
516
+ def check_known_output(accession):
517
+ try:
518
+ # Load credentials from Hugging Face secret
519
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
520
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
521
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
522
+ client = gspread.authorize(creds)
523
+
524
+ spreadsheet = client.open("known_samples")
525
+ sheet = spreadsheet.sheet1
526
+
527
+ data = sheet.get_all_values()
528
+ if not data:
529
+ print("⚠️ Google Sheet 'known_samples' is empty.")
530
+ return None
531
+
532
+ df = pd.DataFrame(data[1:], columns=data[0])
533
+ if "Sample ID" not in df.columns:
534
+ print("❌ Column 'Sample ID' not found in Google Sheet.")
535
+ return None
536
+
537
+ match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
538
+ if match:
539
+ accession = match.group(0)
540
+
541
+ matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
542
+ if not matched.empty:
543
+ #return matched.iloc[0].to_dict()
544
+ row = matched.iloc[0]
545
+ country = row.get("Predicted Country", "").strip().lower()
546
+ sample_type = row.get("Predicted Sample Type", "").strip().lower()
547
+
548
+ if country and country != "unknown" and sample_type and sample_type != "unknown":
549
+ return row.to_dict()
550
+ else:
551
+ print(f"⚠️ Accession {accession} found but country/sample_type is unknown or empty.")
552
+ return None
553
+ else:
554
+ print(f"🔍 Accession {accession} not found in known_samples.")
555
+ return None
556
+
557
+ except Exception as e:
558
+ import traceback
559
+ print("❌ Exception occurred during check_known_output:")
560
+ traceback.print_exc()
561
+ return None
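+ # NOTE: a cached row is only reused when both Predicted Country and Predicted Sample Type are
+ # non-empty and not "unknown"; anything else falls through to the full pipeline in summarize_results.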
562
+
563
+
564
+ def hash_user_id(user_input):
565
+ return hashlib.sha256(user_input.encode()).hexdigest()
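+ # NOTE: summarize_batch passes hash_user_id(user_email) into increment_usage, while load_user_usage()
+ # and save_user_usage() key rows by plain lower-cased email, so the usage sheet ends up keyed by the
+ # hash unless callers agree on one key form. Hypothetical check (email is a placeholder):
+ #   hash_user_id("user@example.com")  # 64-character hex digest, stable across runs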
566
+
567
+ # ✅ Load and save usage count
568
+
569
+ # def load_user_usage():
570
+ # if not os.path.exists(USER_USAGE_TRACK_FILE):
571
+ # return {}
572
+
573
+ # try:
574
+ # with open(USER_USAGE_TRACK_FILE, "r") as f:
575
+ # content = f.read().strip()
576
+ # if not content:
577
+ # return {} # file is empty
578
+ # return json.loads(content)
579
+ # except (json.JSONDecodeError, ValueError):
580
+ # print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
581
+ # return {} # fallback to empty dict
582
+ # def load_user_usage():
583
+ # try:
584
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
585
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
586
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
587
+ # client = gspread.authorize(creds)
588
+
589
+ # sheet = client.open("user_usage_log").sheet1
590
+ # data = sheet.get_all_records() # Assumes columns: email, usage_count
591
+
592
+ # usage = {}
593
+ # for row in data:
594
+ # email = row.get("email", "").strip().lower()
595
+ # count = int(row.get("usage_count", 0))
596
+ # if email:
597
+ # usage[email] = count
598
+ # return usage
599
+ # except Exception as e:
600
+ # print(f"⚠️ Failed to load user usage from Google Sheets: {e}")
601
+ # return {}
602
+ # def load_user_usage():
603
+ # try:
604
+ # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
605
+ # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
606
+
607
+ # found = pipeline.find_drive_file("user_usage_log.json", parent_id=iterate3_id)
608
+ # if not found:
609
+ # return {} # not found, start fresh
610
+
611
+ # #file_id = found[0]["id"]
612
+ # file_id = found
613
+ # content = pipeline.download_drive_file_content(file_id)
614
+ # return json.loads(content.strip()) if content.strip() else {}
615
+
616
+ # except Exception as e:
617
+ # print(f"⚠️ Failed to load user_usage_log.json from Google Drive: {e}")
618
+ # return {}
619
+ def load_user_usage():
620
+ try:
621
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
622
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
623
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
624
+ client = gspread.authorize(creds)
625
+
626
+ sheet = client.open("user_usage_log").sheet1
627
+ data = sheet.get_all_values()
628
+ if not data or len(data) < 2:
629
+ print("⚠️ Sheet is empty or missing rows.")
630
+ return {}, {}
631
+ print("data: ", data)
632
+ print("🧪 Raw header row from sheet:", data[0])
633
+ print("🧪 Character codes in each header:")
634
+ for h in data[0]:
635
+ print([ord(c) for c in h])
636
+
637
+
638
+ headers = [h.strip().lower() for h in data[0]]
639
+ if "email" not in headers or "usage_count" not in headers:
640
+ print("❌ Header format incorrect. Must have 'email' and 'usage_count'.")
641
+ return {}, {}
642
+
643
+ permitted_index = headers.index("permitted_samples") if "permitted_samples" in headers else None
644
+ df = pd.DataFrame(data[1:], columns=headers)
645
+
646
+ usage = {}
647
+ permitted = {}
648
+ for _, row in df.iterrows():
649
+ email = row.get("email", "").strip().lower()
650
+ try:
651
+ #count = int(row.get("usage_count", 0))
652
+ try:
653
+ count = int(float(row.get("usage_count", 0)))
654
+ except Exception:
655
+ print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
656
+ count = 0
657
+
658
+ if email:
659
+ usage[email] = count
660
+ if permitted_index is not None:
661
+ try:
662
+ permitted_count = int(float(row.get("permitted_samples", 50)))
663
+ permitted[email] = permitted_count
664
+ except:
665
+ permitted[email] = 50
666
+
667
+ except ValueError:
668
+ print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
669
+ return usage, permitted
670
+
671
+ except Exception as e:
672
+ print(f"❌ Error in load_user_usage: {e}")
673
+ return {}, {}
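+ # NOTE: load_user_usage always returns a (usage, permitted) pair of dicts keyed by lower-cased email;
+ # permitted stays empty when the sheet has no permitted_samples column, and increment_usage then
+ # falls back to 50. Hypothetical usage (email is a placeholder):
+ #   usage, permitted = load_user_usage()
+ #   print(usage.get("user@example.com", 0), permitted.get("user@example.com", 50))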
674
+
675
+
676
+
677
+ # def save_user_usage(usage):
678
+ # with open(USER_USAGE_TRACK_FILE, "w") as f:
679
+ # json.dump(usage, f, indent=2)
680
+
681
+ # def save_user_usage(usage_dict):
682
+ # try:
683
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
684
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
685
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
686
+ # client = gspread.authorize(creds)
687
+
688
+ # sheet = client.open("user_usage_log").sheet1
689
+ # sheet.clear() # clear old contents first
690
+
691
+ # # Write header + rows
692
+ # rows = [["email", "usage_count"]] + [[email, count] for email, count in usage_dict.items()]
693
+ # sheet.update(rows)
694
+ # except Exception as e:
695
+ # print(f"❌ Failed to save user usage to Google Sheets: {e}")
696
+ # def save_user_usage(usage_dict):
697
+ # try:
698
+ # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
699
+ # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
700
+
701
+ # import tempfile
702
+ # tmp_path = os.path.join(tempfile.gettempdir(), "user_usage_log.json")
703
+ # print("💾 Saving this usage dict:", usage_dict)
704
+ # with open(tmp_path, "w") as f:
705
+ # json.dump(usage_dict, f, indent=2)
706
+
707
+ # pipeline.upload_file_to_drive(tmp_path, "user_usage_log.json", iterate3_id)
708
+
709
+ # except Exception as e:
710
+ # print(f"❌ Failed to save user_usage_log.json to Google Drive: {e}")
711
+ # def save_user_usage(usage_dict):
712
+ # try:
713
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
714
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
715
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
716
+ # client = gspread.authorize(creds)
717
+
718
+ # spreadsheet = client.open("user_usage_log")
719
+ # sheet = spreadsheet.sheet1
720
+
721
+ # # Step 1: Convert new usage to DataFrame
722
+ # df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
723
+ # df_new["email"] = df_new["email"].str.strip().str.lower()
724
+
725
+ # # Step 2: Load existing data
726
+ # existing_data = sheet.get_all_values()
727
+ # print("🧪 Sheet existing_data:", existing_data)
728
+
729
+ # # Try to load old data
730
+ # if existing_data and len(existing_data[0]) >= 1:
731
+ # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
732
+
733
+ # # Fix missing columns
734
+ # if "email" not in df_old.columns:
735
+ # df_old["email"] = ""
736
+ # if "usage_count" not in df_old.columns:
737
+ # df_old["usage_count"] = 0
738
+
739
+ # df_old["email"] = df_old["email"].str.strip().str.lower()
740
+ # df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
741
+ # else:
742
+ # df_old = pd.DataFrame(columns=["email", "usage_count"])
743
+
744
+ # # Step 3: Merge
745
+ # df_combined = pd.concat([df_old, df_new], ignore_index=True)
746
+ # df_combined = df_combined.groupby("email", as_index=False).sum()
747
+
748
+ # # Step 4: Write back
749
+ # sheet.clear()
750
+ # sheet.update([df_combined.columns.tolist()] + df_combined.astype(str).values.tolist())
751
+ # print("✅ Saved user usage to user_usage_log sheet.")
752
+
753
+ # except Exception as e:
754
+ # print(f"❌ Failed to save user usage to Google Sheets: {e}")
755
+ def save_user_usage(usage_dict):
756
+ try:
757
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
758
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
759
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
760
+ client = gspread.authorize(creds)
761
+
762
+ spreadsheet = client.open("user_usage_log")
763
+ sheet = spreadsheet.sheet1
764
+
765
+ # Build new df
766
+ df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
767
+ df_new["email"] = df_new["email"].str.strip().str.lower()
768
+ df_new["usage_count"] = pd.to_numeric(df_new["usage_count"], errors="coerce").fillna(0).astype(int)
769
+
770
+ # Read existing data
771
+ existing_data = sheet.get_all_values()
772
+ if existing_data and len(existing_data[0]) >= 2:
773
+ df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
774
+ df_old["email"] = df_old["email"].str.strip().str.lower()
775
+ df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
776
+ else:
777
+ df_old = pd.DataFrame(columns=["email", "usage_count"])
778
+
779
+ # Overwrite specific emails only
780
+ df_old = df_old.set_index("email")
781
+ for email, count in usage_dict.items():
782
+ email = email.strip().lower()
783
+ df_old.loc[email, "usage_count"] = count
784
+ df_old = df_old.reset_index()
785
+
786
+ # Save
787
+ sheet.clear()
788
+ sheet.update([df_old.columns.tolist()] + df_old.astype(str).values.tolist())
789
+ print("✅ Saved user usage to user_usage_log sheet.")
790
+
791
+ except Exception as e:
792
+ print(f"❌ Failed to save user usage to Google Sheets: {e}")
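+ # NOTE: save_user_usage reads the whole sheet back into df_old before clear()+update(), so extra
+ # columns such as permitted_samples survive the rewrite; only usage_count is overwritten for the
+ # emails present in usage_dict.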
793
+
794
+
795
+
796
+
797
+ # def increment_usage(user_id, num_samples=1):
798
+ # usage = load_user_usage()
799
+ # if user_id not in usage:
800
+ # usage[user_id] = 0
801
+ # usage[user_id] += num_samples
802
+ # save_user_usage(usage)
803
+ # return usage[user_id]
804
+ # def increment_usage(email: str, count: int):
805
+ # usage = load_user_usage()
806
+ # email_key = email.strip().lower()
807
+ # usage[email_key] = usage.get(email_key, 0) + count
808
+ # save_user_usage(usage)
809
+ # return usage[email_key]
810
+ def increment_usage(email: str, count: int = 1):
811
+ usage, permitted = load_user_usage()
812
+ email_key = email.strip().lower()
813
+ #usage[email_key] = usage.get(email_key, 0) + count
814
+ current = usage.get(email_key, 0)
815
+ new_value = current + count
816
+ max_allowed = permitted.get(email_key) or 50
817
+ usage[email_key] = max(current, new_value) # ✅ Prevent overwrite with lower
818
+ print(f"🧪 increment_usage saving: {email_key=} {current=} + {count=} => {usage[email_key]=}")
819
+ print("max allow is: ", max_allowed)
820
+ save_user_usage(usage)
821
+ return usage[email_key], max_allowed
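+ # NOTE: increment_usage returns a (new_total, max_allowed) tuple, so callers must unpack it.
+ # Hypothetical usage (email is a placeholder):
+ #   total, max_allowed = increment_usage("user@example.com", 3)
+ #   if total >= max_allowed: print("quota reached")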
822
+
823
+
824
+ # run the batch
825
+ def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
826
+ stop_flag=None, output_file_path=None,
827
+ limited_acc=50, yield_callback=None):
828
+ if user_email:
829
+ limited_acc += 10
830
+ accessions, error = extract_accessions_from_input(file, raw_text)
831
+ if error:
832
+ #return [], "", "", f"Error: {error}"
833
+ return [], f"Error: {error}", 0, "", ""
834
+ if resume_file:
835
+ accessions = get_incomplete_accessions(resume_file)
836
+ tmp_dir = tempfile.mkdtemp()
837
+ if not output_file_path:
838
+ if resume_file:
839
+ output_file_path = os.path.join(tmp_dir, resume_file)
840
+ else:
841
+ output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
842
+
843
+ all_rows = []
844
+ # all_summaries = []
845
+ # all_flags = []
846
+ progress_lines = []
847
+ warning = ""
848
+ if len(accessions) > limited_acc:
849
+ accessions = accessions[:limited_acc]
850
+ warning = f"You submitted more than {limited_acc} accessions; only the first {limited_acc} will be processed"
851
+ for i, acc in enumerate(accessions):
852
+ if stop_flag and stop_flag.value:
853
+ line = f"🛑 Stopped at {acc} ({i+1}/{len(accessions)})"
854
+ progress_lines.append(line)
855
+ if yield_callback:
856
+ yield_callback(line)
857
+ print("🛑 User requested stop.")
858
+ break
859
+ print(f"[{i+1}/{len(accessions)}] Processing {acc}")
860
+ try:
861
+ # rows, summary, label, explain = summarize_results(acc)
862
+ rows = summarize_results(acc)
863
+ all_rows.extend(rows)
864
+ # all_summaries.append(f"**{acc}**\n{summary}")
865
+ # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
866
+ #save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
867
+ save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
868
+ line = f"✅ Processed {acc} ({i+1}/{len(accessions)})"
869
+ progress_lines.append(line)
870
+ if yield_callback:
871
+ yield_callback(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
872
+ except Exception as e:
873
+ print(f"❌ Failed to process {acc}: {e}")
874
+ continue
875
+ #all_summaries.append(f"**{acc}**: Failed - {e}")
876
+ #progress_lines.append(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
877
+ limited_acc -= 1
878
+ """for row in all_rows:
879
+ source_column = row[2] # Assuming the "Source" is in the 3rd column (index 2)
880
+
881
+ if source_column.startswith("http"): # Check if the source is a URL
882
+ # Wrap it with HTML anchor tags to make it clickable
883
+ row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
884
+ if not warning:
885
+ warning = f"You only have {limited_acc} accessions left"
886
+ if user_email.strip():
887
+ user_hash = hash_user_id(user_email)
888
+ total_queries, _ = increment_usage(user_hash, len(all_rows))
889
+ else:
890
+ total_queries = 0
891
+ if yield_callback: yield_callback("✅ Finished!")
892
+
893
+ # summary_text = "\n\n---\n\n".join(all_summaries)
894
+ # flag_text = "\n\n---\n\n".join(all_flags)
895
+ #return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
896
+ #return all_rows, gr.update(visible=True), gr.update(visible=False)
897
  return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning
mtdna_classifier.py CHANGED
@@ -1,707 +1,714 @@
1
- # mtDNA Location Classifier MVP (Google Colab)
2
- # Accepts accession number → Fetches PubMed ID + isolate name → Gets abstract → Predicts location
3
- import os
4
- #import streamlit as st
5
- import subprocess
6
- import re
7
- from Bio import Entrez
8
- import fitz
9
- import spacy
10
- from spacy.cli import download
11
- from NER.PDF import pdf
12
- from NER.WordDoc import wordDoc
13
- from NER.html import extractHTML
14
- from NER.word2Vec import word2vec
15
- from transformers import pipeline
16
- import urllib.parse, requests
17
- from pathlib import Path
18
- from upgradeClassify import filter_context_for_sample, infer_location_for_sample
19
-
20
- # Set your email (required by NCBI Entrez)
21
- #Entrez.email = "[email protected]"
22
- import nltk
23
-
24
- nltk.download("stopwords")
25
- nltk.download("punkt")
26
- nltk.download('punkt_tab')
27
- # Step 1: Get PubMed ID from Accession using EDirect
28
- from Bio import Entrez, Medline
29
- import re
30
-
31
- Entrez.email = "[email protected]"
32
-
33
- # --- Helper Functions (Re-organized and Upgraded) ---
34
-
35
- def fetch_ncbi_metadata(accession_number):
36
- """
37
- Fetches metadata directly from NCBI GenBank using Entrez.
38
- Includes robust error handling and improved field extraction.
39
- Prioritizes location extraction from geo_loc_name, then notes, then other qualifiers.
40
- Also attempts to extract ethnicity and sample_type (ancient/modern).
41
-
42
- Args:
43
- accession_number (str): The NCBI accession number (e.g., "ON792208").
44
-
45
- Returns:
46
- dict: A dictionary containing 'country', 'specific_location', 'ethnicity',
47
- 'sample_type', 'collection_date', 'isolate', 'title', 'doi', 'pubmed_id'.
48
- """
49
- Entrez.email = "[email protected]" # Required by NCBI, REPLACE WITH YOUR EMAIL
50
-
51
- country = "unknown"
52
- specific_location = "unknown"
53
- ethnicity = "unknown"
54
- sample_type = "unknown"
55
- collection_date = "unknown"
56
- isolate = "unknown"
57
- title = "unknown"
58
- doi = "unknown"
59
- pubmed_id = None
60
- all_feature = "unknown"
61
-
62
- KNOWN_COUNTRIES = [
63
- "Afghanistan", "Albania", "Algeria", "Andorra", "Angola", "Antigua and Barbuda", "Argentina", "Armenia", "Australia", "Austria", "Azerbaijan",
64
- "Bahamas", "Bahrain", "Bangladesh", "Barbados", "Belarus", "Belgium", "Belize", "Benin", "Bhutan", "Bolivia", "Bosnia and Herzegovina", "Botswana", "Brazil", "Brunei", "Bulgaria", "Burkina Faso", "Burundi",
65
- "Cabo Verde", "Cambodia", "Cameroon", "Canada", "Central African Republic", "Chad", "Chile", "China", "Colombia", "Comoros", "Congo (Brazzaville)", "Congo (Kinshasa)", "Costa Rica", "Croatia", "Cuba", "Cyprus", "Czechia",
66
- "Denmark", "Djibouti", "Dominica", "Dominican Republic", "Ecuador", "Egypt", "El Salvador", "Equatorial Guinea", "Eritrea", "Estonia", "Eswatini", "Ethiopia",
67
- "Fiji", "Finland", "France", "Gabon", "Gambia", "Georgia", "Germany", "Ghana", "Greece", "Grenada", "Guatemala", "Guinea", "Guinea-Bissau", "Guyana",
68
- "Haiti", "Honduras", "Hungary", "Iceland", "India", "Indonesia", "Iran", "Iraq", "Ireland", "Israel", "Italy", "Ivory Coast", "Jamaica", "Japan", "Jordan",
69
- "Kazakhstan", "Kenya", "Kiribati", "Kosovo", "Kuwait", "Kyrgyzstan", "Laos", "Latvia", "Lebanon", "Lesotho", "Liberia", "Libya", "Liechtenstein", "Lithuania", "Luxembourg",
70
- "Madagascar", "Malawi", "Malaysia", "Maldives", "Mali", "Malta", "Marshall Islands", "Mauritania", "Mauritius", "Mexico", "Micronesia", "Moldova", "Monaco", "Mongolia", "Montenegro", "Morocco", "Mozambique", "Myanmar",
71
- "Namibia", "Nauru", "Nepal", "Netherlands", "New Zealand", "Nicaragua", "Niger", "Nigeria", "North Korea", "North Macedonia", "Norway", "Oman",
72
- "Pakistan", "Palau", "Palestine", "Panama", "Papua New Guinea", "Paraguay", "Peru", "Philippines", "Poland", "Portugal", "Qatar", "Romania", "Russia", "Rwanda",
73
- "Saint Kitts and Nevis", "Saint Lucia", "Saint Vincent and the Grenadines", "Samoa", "San Marino", "Sao Tome and Principe", "Saudi Arabia", "Senegal", "Serbia", "Seychelles", "Sierra Leone", "Singapore", "Slovakia", "Slovenia", "Solomon Islands", "Somalia", "South Africa", "South Korea", "South Sudan", "Spain", "Sri Lanka", "Sudan", "Suriname", "Sweden", "Switzerland", "Syria",
74
- "Taiwan", "Tajikistan", "Tanzania", "Thailand", "Timor-Leste", "Togo", "Tonga", "Trinidad and Tobago", "Tunisia", "Turkey", "Turkmenistan", "Tuvalu",
75
- "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay", "Uzbekistan", "Vanuatu", "Vatican City", "Venezuela", "Vietnam",
76
- "Yemen", "Zambia", "Zimbabwe"
77
- ]
78
- COUNTRY_PATTERN = re.compile(r'\b(' + '|'.join(re.escape(c) for c in KNOWN_COUNTRIES) + r')\b', re.IGNORECASE)
79
-
80
- try:
81
- handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
82
- record = Entrez.read(handle)
83
- handle.close()
84
-
85
- gb_seq = None
86
- # Validate record structure: It should be a list with at least one element (a dict)
87
- if isinstance(record, list) and len(record) > 0:
88
- if isinstance(record[0], dict):
89
- gb_seq = record[0]
90
- else:
91
- print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
92
- else:
93
- print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
94
-
95
- # If gb_seq is still None, return defaults
96
- if gb_seq is None:
97
- return {"country": "unknown", "specific_location": "unknown", "ethnicity": "unknown",
98
- "sample_type": "unknown", "collection_date": "unknown", "isolate": "unknown",
99
- "title": "unknown", "doi": "unknown", "pubmed_id": None}
100
-
101
-
102
- # If gb_seq is valid, proceed with extraction
103
- collection_date = gb_seq.get("GBSeq_create-date","unknown")
104
-
105
- references = gb_seq.get("GBSeq_references", [])
106
- for ref in references:
107
- if not pubmed_id:
108
- pubmed_id = ref.get("GBReference_pubmed",None)
109
- if title == "unknown":
110
- title = ref.get("GBReference_title","unknown")
111
- for xref in ref.get("GBReference_xref", []):
112
- if xref.get("GBXref_dbname") == "doi":
113
- doi = xref.get("GBXref_id")
114
- break
115
-
116
- features = gb_seq.get("GBSeq_feature-table", [])
117
-
118
- context_for_flagging = "" # Accumulate text for ancient/modern detection
119
- features_context = ""
120
- for feature in features:
121
- if feature.get("GBFeature_key") == "source":
122
- feature_context = ""
123
- qualifiers = feature.get("GBFeature_quals", [])
124
- found_country = "unknown"
125
- found_specific_location = "unknown"
126
- found_ethnicity = "unknown"
127
-
128
- temp_geo_loc_name = "unknown"
129
- temp_note_origin_locality = "unknown"
130
- temp_country_qual = "unknown"
131
- temp_locality_qual = "unknown"
132
- temp_collection_location_qual = "unknown"
133
- temp_isolation_source_qual = "unknown"
134
- temp_env_sample_qual = "unknown"
135
- temp_pop_qual = "unknown"
136
- temp_organism_qual = "unknown"
137
- temp_specimen_qual = "unknown"
138
- temp_strain_qual = "unknown"
139
-
140
- for qual in qualifiers:
141
- qual_name = qual.get("GBQualifier_name")
142
- qual_value = qual.get("GBQualifier_value")
143
- feature_context += qual_name + ": " + qual_value +"\n"
144
- if qual_name == "collection_date":
145
- collection_date = qual_value
146
- elif qual_name == "isolate":
147
- isolate = qual_value
148
- elif qual_name == "population":
149
- temp_pop_qual = qual_value
150
- elif qual_name == "organism":
151
- temp_organism_qual = qual_value
152
- elif qual_name == "specimen_voucher" or qual_name == "specimen":
153
- temp_specimen_qual = qual_value
154
- elif qual_name == "strain":
155
- temp_strain_qual = qual_value
156
- elif qual_name == "isolation_source":
157
- temp_isolation_source_qual = qual_value
158
- elif qual_name == "environmental_sample":
159
- temp_env_sample_qual = qual_value
160
-
161
- if qual_name == "geo_loc_name": temp_geo_loc_name = qual_value
162
- elif qual_name == "note":
163
- if qual_value.startswith("origin_locality:"):
164
- temp_note_origin_locality = qual_value
165
- context_for_flagging += qual_value + " " # Capture all notes for flagging
166
- elif qual_name == "country": temp_country_qual = qual_value
167
- elif qual_name == "locality": temp_locality_qual = qual_value
168
- elif qual_name == "collection_location": temp_collection_location_qual = qual_value
169
-
170
-
171
- # --- Aggregate all relevant info into context_for_flagging ---
172
- context_for_flagging += f" {isolate} {temp_isolation_source_qual} {temp_specimen_qual} {temp_strain_qual} {temp_organism_qual} {temp_geo_loc_name} {temp_collection_location_qual} {temp_env_sample_qual}"
173
- context_for_flagging = context_for_flagging.strip()
174
-
175
- # --- Determine final country and specific_location based on priority ---
176
- if temp_geo_loc_name != "unknown":
177
- parts = [p.strip() for p in temp_geo_loc_name.split(':')]
178
- if len(parts) > 1:
179
- found_specific_location = parts[-1]; found_country = parts[0]
180
- else: found_country = temp_geo_loc_name; found_specific_location = "unknown"
181
- elif temp_note_origin_locality != "unknown":
182
- match = re.search(r"origin_locality:\s*(.*)", temp_note_origin_locality, re.IGNORECASE)
183
- if match:
184
- location_string = match.group(1).strip()
185
- parts = [p.strip() for p in location_string.split(':')]
186
- if len(parts) > 1: found_country = parts[-1]; found_specific_location = parts[0]
187
- else: found_country = location_string; found_specific_location = "unknown"
188
- elif temp_locality_qual != "unknown":
189
- found_country_match = COUNTRY_PATTERN.search(temp_locality_qual)
190
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_locality_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
191
- else: found_specific_location = temp_locality_qual; found_country = "unknown"
192
- elif temp_collection_location_qual != "unknown":
193
- found_country_match = COUNTRY_PATTERN.search(temp_collection_location_qual)
194
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_collection_location_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
195
- else: found_specific_location = temp_collection_location_qual; found_country = "unknown"
196
- elif temp_isolation_source_qual != "unknown":
197
- found_country_match = COUNTRY_PATTERN.search(temp_isolation_source_qual)
198
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_isolation_source_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
199
- else: found_specific_location = temp_isolation_source_qual; found_country = "unknown"
200
- elif temp_env_sample_qual != "unknown":
201
- found_country_match = COUNTRY_PATTERN.search(temp_env_sample_qual)
202
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_env_sample_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
203
- else: found_specific_location = temp_env_sample_qual; found_country = "unknown"
204
- if found_country == "unknown" and temp_country_qual != "unknown":
205
- found_country_match = COUNTRY_PATTERN.search(temp_country_qual)
206
- if found_country_match: found_country = found_country_match.group(1)
207
-
208
- country = found_country
209
- specific_location = found_specific_location
210
- # --- Determine final ethnicity ---
211
- if temp_pop_qual != "unknown":
212
- found_ethnicity = temp_pop_qual
213
- elif isolate != "unknown" and re.fullmatch(r'[A-Za-z\s\-]+', isolate) and get_country_from_text(isolate) == "unknown":
214
- found_ethnicity = isolate
215
- elif context_for_flagging != "unknown": # Use the broader context for ethnicity patterns
216
- eth_match = re.search(r'(?:population|ethnicity|isolate source):\s*([A-Za-z\s\-]+)', context_for_flagging, re.IGNORECASE)
217
- if eth_match:
218
- found_ethnicity = eth_match.group(1).strip()
219
-
220
- ethnicity = found_ethnicity
221
-
222
- # --- Determine sample_type (ancient/modern) ---
223
- if context_for_flagging:
224
- sample_type, explain = detect_ancient_flag(context_for_flagging)
225
- features_context += feature_context + "\n"
226
- break
227
-
228
- if specific_location != "unknown" and specific_location.lower() == country.lower():
229
- specific_location = "unknown"
230
- if not features_context: features_context = "unknown"
231
- return {"country": country.lower(),
232
- "specific_location": specific_location.lower(),
233
- "ethnicity": ethnicity.lower(),
234
- "sample_type": sample_type.lower(),
235
- "collection_date": collection_date,
236
- "isolate": isolate,
237
- "title": title,
238
- "doi": doi,
239
- "pubmed_id": pubmed_id,
240
- "all_features": features_context}
241
-
242
- except Exception as e:
243
- print(f"Error fetching NCBI data for {accession_number}: {e}")
244
- return {"country": "unknown",
245
- "specific_location": "unknown",
246
- "ethnicity": "unknown",
247
- "sample_type": "unknown",
248
- "collection_date": "unknown",
249
- "isolate": "unknown",
250
- "title": "unknown",
251
- "doi": "unknown",
252
- "pubmed_id": None,
253
- "all_features": "unknown"}
254
-
255
- # --- Helper function for country matching (re-defined from main code to be self-contained) ---
256
- _country_keywords = {
257
- "thailand": "Thailand", "laos": "Laos", "cambodia": "Cambodia", "myanmar": "Myanmar",
258
- "philippines": "Philippines", "indonesia": "Indonesia", "malaysia": "Malaysia",
259
- "china": "China", "chinese": "China", "india": "India", "taiwan": "Taiwan",
260
- "vietnam": "Vietnam", "russia": "Russia", "siberia": "Russia", "nepal": "Nepal",
261
- "japan": "Japan", "sumatra": "Indonesia", "borneu": "Indonesia",
262
- "yunnan": "China", "tibet": "China", "northern mindanao": "Philippines",
263
- "west malaysia": "Malaysia", "north thailand": "Thailand", "central thailand": "Thailand",
264
- "northeast thailand": "Thailand", "east myanmar": "Myanmar", "west thailand": "Thailand",
265
- "central india": "India", "east india": "India", "northeast india": "India",
266
- "south sibera": "Russia", "mongolia": "China", "beijing": "China", "south korea": "South Korea",
267
- "north asia": "unknown", "southeast asia": "unknown", "east asia": "unknown"
268
- }
269
-
270
- def get_country_from_text(text):
271
- text_lower = text.lower()
272
- for keyword, country in _country_keywords.items():
273
- if keyword in text_lower:
274
- return country
275
- return "unknown"
276
- # The result will be seen as manualLink for the function get_paper_text
277
- def search_google_custom(query, max_results=3):
278
- # query should be the title from ncbi or paper/source title
279
- GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
280
- GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
281
- endpoint = os.environ["SEARCH_ENDPOINT"]
282
- params = {
283
- "key": GOOGLE_CSE_API_KEY,
284
- "cx": GOOGLE_CSE_CX,
285
- "q": query,
286
- "num": max_results
287
- }
288
- try:
289
- response = requests.get(endpoint, params=params)
290
- if response.status_code == 429:
291
- print("Rate limit hit. Try again later.")
292
- return []
293
- response.raise_for_status()
294
- data = response.json().get("items", [])
295
- return [item.get("link") for item in data if item.get("link")]
296
- except Exception as e:
297
- print("Google CSE error:", e)
298
- return []
299
- # Step 3: Extract Text: Get the paper (html text), sup. materials (pdf, doc, excel) and do text-preprocessing
300
- # Step 3.1: Extract Text
301
- # sub: download excel file
302
- def download_excel_file(url, save_path="temp.xlsx"):
303
- if "view.officeapps.live.com" in url:
304
- parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
305
- real_url = urllib.parse.unquote(parsed_url["src"][0])
306
- response = requests.get(real_url)
307
- with open(save_path, "wb") as f:
308
- f.write(response.content)
309
- return save_path
310
- elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
311
- response = requests.get(url)
312
- response.raise_for_status() # Raises error if download fails
313
- with open(save_path, "wb") as f:
314
- f.write(response.content)
315
- return save_path
316
- else:
317
- print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
318
- return url
319
- def get_paper_text(doi,id,manualLinks=None):
320
- # create the temporary folder to contain the texts
321
- folder_path = Path("data/"+str(id))
322
- if not folder_path.exists():
323
- cmd = f'mkdir data/{id}'
324
- result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
325
- print("data/"+str(id) +" created.")
326
- else:
327
- print("data/"+str(id) +" already exists.")
328
- saveLinkFolder = "data/"+id
329
-
330
- link = 'https://doi.org/' + doi
331
- '''textsToExtract = { "doiLink":"paperText"
332
- "file1.pdf":"text1",
333
- "file2.doc":"text2",
334
- "file3.xlsx":excelText3'''
335
- textsToExtract = {}
336
- # get the file to create listOfFile for each id
337
- html = extractHTML.HTML("",link)
338
- jsonSM = html.getSupMaterial()
339
- text = ""
340
- links = [link] + sum((jsonSM[key] for key in jsonSM),[])
341
- if manualLinks != None:
342
- links += manualLinks
343
- for l in links:
344
- # get the main paper
345
- name = l.split("/")[-1]
346
- file_path = folder_path / name
347
- if l == link:
348
- text = html.getListSection()
349
- textsToExtract[link] = text
350
- elif l.endswith(".pdf"):
351
- if file_path.is_file():
352
- l = saveLinkFolder + "/" + name
353
- print("File exists.")
354
- p = pdf.PDF(l,saveLinkFolder,doi)
355
- f = p.openPDFFile()
356
- pdf_path = saveLinkFolder + "/" + l.split("/")[-1]
357
- doc = fitz.open(pdf_path)
358
- text = "\n".join([page.get_text() for page in doc])
359
- textsToExtract[l] = text
360
- elif l.endswith(".doc") or l.endswith(".docx"):
361
- d = wordDoc.wordDoc(l,saveLinkFolder)
362
- text = d.extractTextByPage()
363
- textsToExtract[l] = text
364
- elif l.split(".")[-1].lower() in "xlsx":
365
- wc = word2vec.word2Vec()
366
- # download excel file if it not downloaded yet
367
- savePath = saveLinkFolder +"/"+ l.split("/")[-1]
368
- excelPath = download_excel_file(l, savePath)
369
- corpus = wc.tableTransformToCorpusText([],excelPath)
370
- text = ''
371
- for c in corpus:
372
- para = corpus[c]
373
- for words in para:
374
- text += " ".join(words)
375
- textsToExtract[l] = text
376
- # delete folder after finishing getting text
377
- #cmd = f'rm -r data/{id}'
378
- #result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
379
- return textsToExtract
380
- # Step 3.2: Extract context
381
- def extract_context(text, keyword, window=500):
382
- # firstly try accession number
383
- idx = text.find(keyword)
384
- if idx == -1:
385
- return "Sample ID not found."
386
- return text[max(0, idx-window): idx+window]
387
- def extract_relevant_paragraphs(text, accession, keep_if=None, isolate=None):
388
- if keep_if is None:
389
- keep_if = ["sample", "method", "mtdna", "sequence", "collected", "dataset", "supplementary", "table"]
390
-
391
- outputs = ""
392
- text = text.lower()
393
-
394
- # If isolate is provided, prioritize paragraphs that mention it
395
- # If isolate is provided, prioritize paragraphs that mention it
396
- if accession and accession.lower() in text:
397
- if extract_context(text, accession.lower(), window=700) != "Sample ID not found.":
398
- outputs += extract_context(text, accession.lower(), window=700)
399
- if isolate and isolate.lower() in text:
400
- if extract_context(text, isolate.lower(), window=700) != "Sample ID not found.":
401
- outputs += extract_context(text, isolate.lower(), window=700)
402
- for keyword in keep_if:
403
- para = extract_context(text, keyword)
404
- if para and para not in outputs:
405
- outputs += para + "\n"
406
- return outputs
407
- # Step 4: Classification for now (demo purposes)
408
- # 4.1: Using a HuggingFace model (question-answering)
409
- def infer_fromQAModel(context, question="Where is the mtDNA sample from?"):
410
- try:
411
- qa = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
412
- result = qa({"context": context, "question": question})
413
- return result.get("answer", "Unknown")
414
- except Exception as e:
415
- return f"Error: {str(e)}"
416
-
417
- # 4.2: Infer from haplogroup
418
- # Load pre-trained spaCy model for NER
419
- try:
420
- nlp = spacy.load("en_core_web_sm")
421
- except OSError:
422
- download("en_core_web_sm")
423
- nlp = spacy.load("en_core_web_sm")
424
-
425
- # Define the haplogroup-to-region mapping (simple rule-based)
426
- import csv
427
-
428
- def load_haplogroup_mapping(csv_path):
429
- mapping = {}
430
- with open(csv_path) as f:
431
- reader = csv.DictReader(f)
432
- for row in reader:
433
- mapping[row["haplogroup"]] = [row["region"],row["source"]]
434
- return mapping
435
-
436
- # Function to extract haplogroup from the text
437
- def extract_haplogroup(text):
438
- match = re.search(r'\bhaplogroup\s+([A-Z][0-9a-z]*)\b', text)
439
- if match:
440
- submatch = re.match(r'^[A-Z][0-9]*', match.group(1))
441
- if submatch:
442
- return submatch.group(0)
443
- else:
444
- return match.group(1) # fallback
445
- fallback = re.search(r'\b([A-Z][0-9a-z]{1,5})\b', text)
446
- if fallback:
447
- return fallback.group(1)
448
- return None
449
-
450
-
451
- # Function to extract location based on NER
452
- def extract_location(text):
453
- doc = nlp(text)
454
- locations = []
455
- for ent in doc.ents:
456
- if ent.label_ == "GPE": # GPE = Geopolitical Entity (location)
457
- locations.append(ent.text)
458
- return locations
459
-
460
- # Function to infer location from haplogroup
461
- def infer_location_from_haplogroup(haplogroup):
462
- haplo_map = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
463
- return haplo_map.get(haplogroup, ["Unknown","Unknown"])
464
-
465
- # Function to classify the mtDNA sample
466
- def classify_mtDNA_sample_from_haplo(text):
467
- # Extract haplogroup
468
- haplogroup = extract_haplogroup(text)
469
- # Extract location based on NER
470
- locations = extract_location(text)
471
- # Infer location based on haplogroup
472
- inferred_location, sourceHaplo = infer_location_from_haplogroup(haplogroup)[0],infer_location_from_haplogroup(haplogroup)[1]
473
- return {
474
- "source":sourceHaplo,
475
- "locations_found_in_context": locations,
476
- "haplogroup": haplogroup,
477
- "inferred_location": inferred_location
478
-
479
- }
480
- # 4.3 Get from available NCBI
481
- def infer_location_fromNCBI(accession):
482
- try:
483
- handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
484
- text = handle.read()
485
- handle.close()
486
- match = re.search(r'/(geo_loc_name|country|location)\s*=\s*"([^"]+)"', text)
487
- if match:
488
- return match.group(2), match.group(0) # This is the value like "Brunei"
489
- return "Not found", "Not found"
490
-
491
- except Exception as e:
492
- print("❌ Entrez error:", e)
493
- return "Not found", "Not found"
494
-
495
- ### ANCIENT/MODERN FLAG
496
- from Bio import Entrez
497
- import re
498
-
499
- def flag_ancient_modern(accession, textsToExtract, isolate=None):
500
- """
501
- Try to classify a sample as Ancient or Modern using:
502
- 1. NCBI accession (if available)
503
- 2. Supplementary text or context fallback
504
- """
505
- context = ""
506
- label, explain = "", ""
507
-
508
- try:
509
- # Check if we can fetch metadata from NCBI using the accession
510
- handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
511
- text = handle.read()
512
- handle.close()
513
-
514
- isolate_source = re.search(r'/(isolation_source)\s*=\s*"([^"]+)"', text)
515
- if isolate_source:
516
- context += isolate_source.group(0) + " "
517
-
518
- specimen = re.search(r'/(specimen|specimen_voucher)\s*=\s*"([^"]+)"', text)
519
- if specimen:
520
- context += specimen.group(0) + " "
521
-
522
- if context.strip():
523
- label, explain = detect_ancient_flag(context)
524
- if label!="Unknown":
525
- return label, explain + " from NCBI\n(" + context + ")"
526
-
527
- # If no useful NCBI metadata, check supplementary texts
528
- if textsToExtract:
529
- labels = {"modern": [0, ""], "ancient": [0, ""], "unknown": 0}
530
-
531
- for source in textsToExtract:
532
- text_block = textsToExtract[source]
533
- context = extract_relevant_paragraphs(text_block, accession, isolate=isolate) # Reduce to informative paragraph(s)
534
- label, explain = detect_ancient_flag(context)
535
-
536
- if label == "Ancient":
537
- labels["ancient"][0] += 1
538
- labels["ancient"][1] += f"{source}:\n{explain}\n\n"
539
- elif label == "Modern":
540
- labels["modern"][0] += 1
541
- labels["modern"][1] += f"{source}:\n{explain}\n\n"
542
- else:
543
- labels["unknown"] += 1
544
-
545
- if max(labels["modern"][0],labels["ancient"][0]) > 0:
546
- if labels["modern"][0] > labels["ancient"][0]:
547
- return "Modern", labels["modern"][1]
548
- else:
549
- return "Ancient", labels["ancient"][1]
550
- else:
551
- return "Unknown", "No strong keywords detected"
552
- else:
553
- print("No DOI or PubMed ID available for inference.")
554
- return "", ""
555
-
556
- except Exception as e:
557
- print("Error:", e)
558
- return "", ""
559
-
560
-
561
- def detect_ancient_flag(context_snippet):
562
- context = context_snippet.lower()
563
-
564
- ancient_keywords = [
565
- "ancient", "archaeological", "prehistoric", "neolithic", "mesolithic", "paleolithic",
566
- "bronze age", "iron age", "burial", "tomb", "skeleton", "14c", "radiocarbon", "carbon dating",
567
- "postmortem damage", "udg treatment", "adna", "degradation", "site", "excavation",
568
- "archaeological context", "temporal transect", "population replacement", "cal bp", "calbp", "carbon dated"
569
- ]
570
-
571
- modern_keywords = [
572
- "modern", "hospital", "clinical", "consent","blood","buccal","unrelated", "blood sample","buccal sample","informed consent", "donor", "healthy", "patient",
573
- "genotyping", "screening", "medical", "cohort", "sequencing facility", "ethics approval",
574
- "we analysed", "we analyzed", "dataset includes", "new sequences", "published data",
575
- "control cohort", "sink population", "genbank accession", "sequenced", "pipeline",
576
- "bioinformatic analysis", "samples from", "population genetics", "genome-wide data", "imr collection"
577
- ]
578
-
579
- ancient_hits = [k for k in ancient_keywords if k in context]
580
- modern_hits = [k for k in modern_keywords if k in context]
581
-
582
- if ancient_hits and not modern_hits:
583
- return "Ancient", f"Flagged as ancient due to keywords: {', '.join(ancient_hits)}"
584
- elif modern_hits and not ancient_hits:
585
- return "Modern", f"Flagged as modern due to keywords: {', '.join(modern_hits)}"
586
- elif ancient_hits and modern_hits:
587
- if len(ancient_hits) >= len(modern_hits):
588
- return "Ancient", f"Mixed context, leaning ancient due to: {', '.join(ancient_hits)}"
589
- else:
590
- return "Modern", f"Mixed context, leaning modern due to: {', '.join(modern_hits)}"
591
-
592
- # Fallback to QA
593
- answer = infer_fromQAModel(context, question="Are the mtDNA samples ancient or modern? Explain why.")
594
- if answer.startswith("Error"):
595
- return "Unknown", answer
596
- if "ancient" in answer.lower():
597
- return "Ancient", f"Leaning ancient based on QA: {answer}"
598
- elif "modern" in answer.lower():
599
- return "Modern", f"Leaning modern based on QA: {answer}"
600
- else:
601
- return "Unknown", f"No strong keywords or QA clues. QA said: {answer}"
602
-
603
- # STEP 5: Main pipeline: accession -> 1. get pubmed id and isolate -> 2. get doi -> 3. get text -> 4. prediction -> 5. output: inferred location + explanation + confidence score
604
- def classify_sample_location(accession):
605
- outputs = {}
606
- keyword, context, location, qa_result, haplo_result = "", "", "", "", ""
607
- # Step 1: get pubmed id and isolate
608
- pubmedID, isolate = get_info_from_accession(accession)
609
- '''if not pubmedID:
610
- return {"error": f"Could not retrieve PubMed ID for accession {accession}"}'''
611
- if not isolate:
612
- isolate = "UNKNOWN_ISOLATE"
613
- # Step 2: get doi
614
- doi = get_doi_from_pubmed_id(pubmedID)
615
- '''if not doi:
616
- return {"error": "DOI not found for this accession. Cannot fetch paper or context."}'''
617
- # Step 3: get text
618
- '''textsToExtract = { "doiLink":"paperText"
619
- "file1.pdf":"text1",
620
- "file2.doc":"text2",
621
- "file3.xlsx":excelText3'''
622
- if doi and pubmedID:
623
- textsToExtract = get_paper_text(doi,pubmedID)
624
- else: textsToExtract = {}
625
- '''if not textsToExtract:
626
- return {"error": f"No texts extracted for DOI {doi}"}'''
627
- if isolate not in [None, "UNKNOWN_ISOLATE"]:
628
- label, explain = flag_ancient_modern(accession,textsToExtract,isolate)
629
- else:
630
- label, explain = flag_ancient_modern(accession,textsToExtract)
631
- # Step 4: prediction
632
- outputs[accession] = {}
633
- outputs[isolate] = {}
634
- # 4.0 Infer from NCBI
635
- location, outputNCBI = infer_location_fromNCBI(accession)
636
- NCBI_result = {
637
- "source": "NCBI",
638
- "sample_id": accession,
639
- "predicted_location": location,
640
- "context_snippet": outputNCBI}
641
- outputs[accession]["NCBI"]= {"NCBI": NCBI_result}
642
- if textsToExtract:
643
- long_text = ""
644
- for key in textsToExtract:
645
- text = textsToExtract[key]
646
- # try accession number first
647
- outputs[accession][key] = {}
648
- keyword = accession
649
- context = extract_context(text, keyword, window=500)
650
- # 4.1: Using a HuggingFace model (question-answering)
651
- location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
652
- qa_result = {
653
- "source": key,
654
- "sample_id": keyword,
655
- "predicted_location": location,
656
- "context_snippet": context
657
- }
658
- outputs[keyword][key]["QAModel"] = qa_result
659
- # 4.2: Infer from haplogroup
660
- haplo_result = classify_mtDNA_sample_from_haplo(context)
661
- outputs[keyword][key]["haplogroup"] = haplo_result
662
- # try isolate
663
- keyword = isolate
664
- outputs[isolate][key] = {}
665
- context = extract_context(text, keyword, window=500)
666
- # 4.1.1: Using a HuggingFace model (question-answering)
667
- location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
668
- qa_result = {
669
- "source": key,
670
- "sample_id": keyword,
671
- "predicted_location": location,
672
- "context_snippet": context
673
- }
674
- outputs[keyword][key]["QAModel"] = qa_result
675
- # 4.2.1: Infer from haplogroup
676
- haplo_result = classify_mtDNA_sample_from_haplo(context)
677
- outputs[keyword][key]["haplogroup"] = haplo_result
678
- # add long text
679
- long_text += text + ". \n"
680
- # 4.3: UpgradeClassify
681
- # try sample_id as accession number
682
- sample_id = accession
683
- if sample_id:
684
- filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
685
- locations = infer_location_for_sample(sample_id.upper(), filtered_context)
686
- if locations!="No clear location found in top matches":
687
- outputs[sample_id]["upgradeClassifier"] = {}
688
- outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
689
- "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
690
- "sample_id": sample_id,
691
- "predicted_location": ", ".join(locations),
692
- "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
693
- }
694
- # try sample_id as isolate name
695
- sample_id = isolate
696
- if sample_id:
697
- filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
698
- locations = infer_location_for_sample(sample_id.upper(), filtered_context)
699
- if locations!="No clear location found in top matches":
700
- outputs[sample_id]["upgradeClassifier"] = {}
701
- outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
702
- "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
703
- "sample_id": sample_id,
704
- "predicted_location": ", ".join(locations),
705
- "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
706
- }
707
  return outputs, label, explain
 
1
+ # mtDNA Location Classifier MVP (Google Colab)
2
+ # Accepts accession number → Fetches PubMed ID + isolate name → Gets abstract → Predicts location
3
+ import os
4
+ #import streamlit as st
5
+ import subprocess
6
+ import re
7
+ from Bio import Entrez
8
+ import fitz
9
+ import spacy
10
+ from spacy.cli import download
11
+ from NER.PDF import pdf
12
+ from NER.WordDoc import wordDoc
13
+ from NER.html import extractHTML
14
+ from NER.word2Vec import word2vec
15
+ from transformers import pipeline
16
+ import urllib.parse, requests
17
+ from pathlib import Path
18
+ from upgradeClassify import filter_context_for_sample, infer_location_for_sample
19
+
20
+ # Set your email (required by NCBI Entrez)
21
+ #Entrez.email = "[email protected]"
22
+ import nltk
23
+
24
+ nltk.download("stopwords")
25
+ nltk.download("punkt")
26
+ nltk.download('punkt_tab')
27
+ # Step 1: Get PubMed ID from Accession using EDirect
28
+ from Bio import Entrez, Medline
29
+ import re
30
+
31
+ Entrez.email = "[email protected]"
32
+
33
+ # --- Helper Functions (Re-organized and Upgraded) ---
34
+
35
+ def fetch_ncbi_metadata(accession_number):
36
+ """
37
+ Fetches metadata directly from NCBI GenBank using Entrez.
38
+ Includes robust error handling and improved field extraction.
39
+ Prioritizes location extraction from geo_loc_name, then notes, then other qualifiers.
40
+ Also attempts to extract ethnicity and sample_type (ancient/modern).
41
+
42
+ Args:
43
+ accession_number (str): The NCBI accession number (e.g., "ON792208").
44
+
45
+ Returns:
46
+ dict: A dictionary containing 'country', 'specific_location', 'ethnicity',
47
+ 'sample_type', 'collection_date', 'isolate', 'title', 'doi', 'pubmed_id'.
48
+ """
49
+ Entrez.email = "[email protected]" # Required by NCBI, REPLACE WITH YOUR EMAIL
50
+
51
+ country = "unknown"
52
+ specific_location = "unknown"
53
+ ethnicity = "unknown"
54
+ sample_type = "unknown"
55
+ collection_date = "unknown"
56
+ isolate = "unknown"
57
+ title = "unknown"
58
+ doi = "unknown"
59
+ pubmed_id = None
60
+ all_feature = "unknown"
61
+
62
+ KNOWN_COUNTRIES = [
63
+ "Afghanistan", "Albania", "Algeria", "Andorra", "Angola", "Antigua and Barbuda", "Argentina", "Armenia", "Australia", "Austria", "Azerbaijan",
64
+ "Bahamas", "Bahrain", "Bangladesh", "Barbados", "Belarus", "Belgium", "Belize", "Benin", "Bhutan", "Bolivia", "Bosnia and Herzegovina", "Botswana", "Brazil", "Brunei", "Bulgaria", "Burkina Faso", "Burundi",
65
+ "Cabo Verde", "Cambodia", "Cameroon", "Canada", "Central African Republic", "Chad", "Chile", "China", "Colombia", "Comoros", "Congo (Brazzaville)", "Congo (Kinshasa)", "Costa Rica", "Croatia", "Cuba", "Cyprus", "Czechia",
66
+ "Denmark", "Djibouti", "Dominica", "Dominican Republic", "Ecuador", "Egypt", "El Salvador", "Equatorial Guinea", "Eritrea", "Estonia", "Eswatini", "Ethiopia",
67
+ "Fiji", "Finland", "France", "Gabon", "Gambia", "Georgia", "Germany", "Ghana", "Greece", "Grenada", "Guatemala", "Guinea", "Guinea-Bissau", "Guyana",
68
+ "Haiti", "Honduras", "Hungary", "Iceland", "India", "Indonesia", "Iran", "Iraq", "Ireland", "Israel", "Italy", "Ivory Coast", "Jamaica", "Japan", "Jordan",
69
+ "Kazakhstan", "Kenya", "Kiribati", "Kosovo", "Kuwait", "Kyrgyzstan", "Laos", "Latvia", "Lebanon", "Lesotho", "Liberia", "Libya", "Liechtenstein", "Lithuania", "Luxembourg",
70
+ "Madagascar", "Malawi", "Malaysia", "Maldives", "Mali", "Malta", "Marshall Islands", "Mauritania", "Mauritius", "Mexico", "Micronesia", "Moldova", "Monaco", "Mongolia", "Montenegro", "Morocco", "Mozambique", "Myanmar",
71
+ "Namibia", "Nauru", "Nepal", "Netherlands", "New Zealand", "Nicaragua", "Niger", "Nigeria", "North Korea", "North Macedonia", "Norway", "Oman",
72
+ "Pakistan", "Palau", "Palestine", "Panama", "Papua New Guinea", "Paraguay", "Peru", "Philippines", "Poland", "Portugal", "Qatar", "Romania", "Russia", "Rwanda",
73
+ "Saint Kitts and Nevis", "Saint Lucia", "Saint Vincent and the Grenadines", "Samoa", "San Marino", "Sao Tome and Principe", "Saudi Arabia", "Senegal", "Serbia", "Seychelles", "Sierra Leone", "Singapore", "Slovakia", "Slovenia", "Solomon Islands", "Somalia", "South Africa", "South Korea", "South Sudan", "Spain", "Sri Lanka", "Sudan", "Suriname", "Sweden", "Switzerland", "Syria",
74
+ "Taiwan", "Tajikistan", "Tanzania", "Thailand", "Timor-Leste", "Togo", "Tonga", "Trinidad and Tobago", "Tunisia", "Turkey", "Turkmenistan", "Tuvalu",
75
+ "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay", "Uzbekistan", "Vanuatu", "Vatican City", "Venezuela", "Vietnam",
76
+ "Yemen", "Zambia", "Zimbabwe"
77
+ ]
78
+ COUNTRY_PATTERN = re.compile(r'\b(' + '|'.join(re.escape(c) for c in KNOWN_COUNTRIES) + r')\b', re.IGNORECASE)
79
+
80
+ try:
81
+ handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
82
+ record = Entrez.read(handle)
83
+ handle.close()
84
+
85
+ gb_seq = None
86
+ # Validate record structure: It should be a list with at least one element (a dict)
87
+ if isinstance(record, list) and len(record) > 0:
88
+ if isinstance(record[0], dict):
89
+ gb_seq = record[0]
90
+ else:
91
+ print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
92
+ else:
93
+ print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
94
+
95
+ # If gb_seq is still None, return defaults
96
+ if gb_seq is None:
97
+ return {"country": "unknown",
98
+ "specific_location": "unknown",
99
+ "ethnicity": "unknown",
100
+ "sample_type": "unknown",
101
+ "collection_date": "unknown",
102
+ "isolate": "unknown",
103
+ "title": "unknown",
104
+ "doi": "unknown",
105
+ "pubmed_id": None,
106
+ "all_features": "unknown"}
107
+
108
+
109
+ # If gb_seq is valid, proceed with extraction
110
+ collection_date = gb_seq.get("GBSeq_create-date","unknown")
111
+
112
+ references = gb_seq.get("GBSeq_references", [])
113
+ for ref in references:
114
+ if not pubmed_id:
115
+ pubmed_id = ref.get("GBReference_pubmed",None)
116
+ if title == "unknown":
117
+ title = ref.get("GBReference_title","unknown")
118
+ for xref in ref.get("GBReference_xref", []):
119
+ if xref.get("GBXref_dbname") == "doi":
120
+ doi = xref.get("GBXref_id")
121
+ break
122
+
123
+ features = gb_seq.get("GBSeq_feature-table", [])
124
+
125
+ context_for_flagging = "" # Accumulate text for ancient/modern detection
126
+ features_context = ""
127
+ for feature in features:
128
+ if feature.get("GBFeature_key") == "source":
129
+ feature_context = ""
130
+ qualifiers = feature.get("GBFeature_quals", [])
131
+ found_country = "unknown"
132
+ found_specific_location = "unknown"
133
+ found_ethnicity = "unknown"
134
+
135
+ temp_geo_loc_name = "unknown"
136
+ temp_note_origin_locality = "unknown"
137
+ temp_country_qual = "unknown"
138
+ temp_locality_qual = "unknown"
139
+ temp_collection_location_qual = "unknown"
140
+ temp_isolation_source_qual = "unknown"
141
+ temp_env_sample_qual = "unknown"
142
+ temp_pop_qual = "unknown"
143
+ temp_organism_qual = "unknown"
144
+ temp_specimen_qual = "unknown"
145
+ temp_strain_qual = "unknown"
146
+
147
+ for qual in qualifiers:
148
+ qual_name = qual.get("GBQualifier_name")
149
+ qual_value = qual.get("GBQualifier_value")
150
+ feature_context += qual_name + ": " + qual_value +"\n"
151
+ if qual_name == "collection_date":
152
+ collection_date = qual_value
153
+ elif qual_name == "isolate":
154
+ isolate = qual_value
155
+ elif qual_name == "population":
156
+ temp_pop_qual = qual_value
157
+ elif qual_name == "organism":
158
+ temp_organism_qual = qual_value
159
+ elif qual_name == "specimen_voucher" or qual_name == "specimen":
160
+ temp_specimen_qual = qual_value
161
+ elif qual_name == "strain":
162
+ temp_strain_qual = qual_value
163
+ elif qual_name == "isolation_source":
164
+ temp_isolation_source_qual = qual_value
165
+ elif qual_name == "environmental_sample":
166
+ temp_env_sample_qual = qual_value
167
+
168
+ if qual_name == "geo_loc_name": temp_geo_loc_name = qual_value
169
+ elif qual_name == "note":
170
+ if qual_value.startswith("origin_locality:"):
171
+ temp_note_origin_locality = qual_value
172
+ context_for_flagging += qual_value + " " # Capture all notes for flagging
173
+ elif qual_name == "country": temp_country_qual = qual_value
174
+ elif qual_name == "locality": temp_locality_qual = qual_value
175
+ elif qual_name == "collection_location": temp_collection_location_qual = qual_value
176
+
177
+
178
+ # --- Aggregate all relevant info into context_for_flagging ---
179
+ context_for_flagging += f" {isolate} {temp_isolation_source_qual} {temp_specimen_qual} {temp_strain_qual} {temp_organism_qual} {temp_geo_loc_name} {temp_collection_location_qual} {temp_env_sample_qual}"
180
+ context_for_flagging = context_for_flagging.strip()
181
+
182
+ # --- Determine final country and specific_location based on priority ---
183
+ if temp_geo_loc_name != "unknown":
184
+ parts = [p.strip() for p in temp_geo_loc_name.split(':')]
185
+ if len(parts) > 1:
186
+ found_specific_location = parts[-1]; found_country = parts[0]
187
+ else: found_country = temp_geo_loc_name; found_specific_location = "unknown"
188
+ elif temp_note_origin_locality != "unknown":
189
+ match = re.search(r"origin_locality:\s*(.*)", temp_note_origin_locality, re.IGNORECASE)
190
+ if match:
191
+ location_string = match.group(1).strip()
192
+ parts = [p.strip() for p in location_string.split(':')]
193
+ if len(parts) > 1: found_country = parts[-1]; found_specific_location = parts[0]
194
+ else: found_country = location_string; found_specific_location = "unknown"
195
+ elif temp_locality_qual != "unknown":
196
+ found_country_match = COUNTRY_PATTERN.search(temp_locality_qual)
197
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_locality_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
198
+ else: found_specific_location = temp_locality_qual; found_country = "unknown"
199
+ elif temp_collection_location_qual != "unknown":
200
+ found_country_match = COUNTRY_PATTERN.search(temp_collection_location_qual)
201
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_collection_location_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
202
+ else: found_specific_location = temp_collection_location_qual; found_country = "unknown"
203
+ elif temp_isolation_source_qual != "unknown":
204
+ found_country_match = COUNTRY_PATTERN.search(temp_isolation_source_qual)
205
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_isolation_source_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
206
+ else: found_specific_location = temp_isolation_source_qual; found_country = "unknown"
207
+ elif temp_env_sample_qual != "unknown":
208
+ found_country_match = COUNTRY_PATTERN.search(temp_env_sample_qual)
209
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_env_sample_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
210
+ else: found_specific_location = temp_env_sample_qual; found_country = "unknown"
211
+ if found_country == "unknown" and temp_country_qual != "unknown":
212
+ found_country_match = COUNTRY_PATTERN.search(temp_country_qual)
213
+ if found_country_match: found_country = found_country_match.group(1)
214
+
215
+ country = found_country
216
+ specific_location = found_specific_location
217
+ # --- Determine final ethnicity ---
218
+ if temp_pop_qual != "unknown":
219
+ found_ethnicity = temp_pop_qual
220
+ elif isolate != "unknown" and re.fullmatch(r'[A-Za-z\s\-]+', isolate) and get_country_from_text(isolate) == "unknown":
221
+ found_ethnicity = isolate
222
+ elif context_for_flagging != "unknown": # Use the broader context for ethnicity patterns
223
+ eth_match = re.search(r'(?:population|ethnicity|isolate source):\s*([A-Za-z\s\-]+)', context_for_flagging, re.IGNORECASE)
224
+ if eth_match:
225
+ found_ethnicity = eth_match.group(1).strip()
226
+
227
+ ethnicity = found_ethnicity
228
+
229
+ # --- Determine sample_type (ancient/modern) ---
230
+ if context_for_flagging:
231
+ sample_type, explain = detect_ancient_flag(context_for_flagging)
232
+ features_context += feature_context + "\n"
233
+ break
234
+
235
+ if specific_location != "unknown" and specific_location.lower() == country.lower():
236
+ specific_location = "unknown"
237
+ if not features_context: features_context = "unknown"
238
+ return {"country": country.lower(),
239
+ "specific_location": specific_location.lower(),
240
+ "ethnicity": ethnicity.lower(),
241
+ "sample_type": sample_type.lower(),
242
+ "collection_date": collection_date,
243
+ "isolate": isolate,
244
+ "title": title,
245
+ "doi": doi,
246
+ "pubmed_id": pubmed_id,
247
+ "all_features": features_context}
248
+
249
+ except Exception as e:
250
+ print(f"Error fetching NCBI data for {accession_number}: {e}")
251
+ return {"country": "unknown",
252
+ "specific_location": "unknown",
253
+ "ethnicity": "unknown",
254
+ "sample_type": "unknown",
255
+ "collection_date": "unknown",
256
+ "isolate": "unknown",
257
+ "title": "unknown",
258
+ "doi": "unknown",
259
+ "pubmed_id": None,
260
+ "all_features": "unknown"}
261
+
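+ # Usage sketch for fetch_ncbi_metadata (illustrative; assumes NCBI is reachable and
+ # Entrez.email is set to a real address; "ON792208" is the docstring's example accession):
+ #   meta = fetch_ncbi_metadata("ON792208")
+ #   print(meta["country"], meta["sample_type"], meta["pubmed_id"])
+ #   # -> a lower-cased country (or "unknown"), an ancient/modern flag, and a PubMed ID or None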
262
+ # --- Helper function for country matching (re-defined from main code to be self-contained) ---
263
+ _country_keywords = {
264
+ "thailand": "Thailand", "laos": "Laos", "cambodia": "Cambodia", "myanmar": "Myanmar",
265
+ "philippines": "Philippines", "indonesia": "Indonesia", "malaysia": "Malaysia",
266
+ "china": "China", "chinese": "China", "india": "India", "taiwan": "Taiwan",
267
+ "vietnam": "Vietnam", "russia": "Russia", "siberia": "Russia", "nepal": "Nepal",
268
+ "japan": "Japan", "sumatra": "Indonesia", "borneu": "Indonesia",
269
+ "yunnan": "China", "tibet": "China", "northern mindanao": "Philippines",
270
+ "west malaysia": "Malaysia", "north thailand": "Thailand", "central thailand": "Thailand",
271
+ "northeast thailand": "Thailand", "east myanmar": "Myanmar", "west thailand": "Thailand",
272
+ "central india": "India", "east india": "India", "northeast india": "India",
273
+ "south sibera": "Russia", "mongolia": "China", "beijing": "China", "south korea": "South Korea",
274
+ "north asia": "unknown", "southeast asia": "unknown", "east asia": "unknown"
275
+ }
276
+
277
+ def get_country_from_text(text):
278
+ text_lower = text.lower()
279
+ for keyword, country in _country_keywords.items():
280
+ if keyword in text_lower:
281
+ return country
282
+ return "unknown"
283
+ # The result will be seen as manualLink for the function get_paper_text
284
+ def search_google_custom(query, max_results=3):
285
+ # query should be the title from ncbi or paper/source title
286
+ GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
287
+ GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
288
+ endpoint = os.environ["SEARCH_ENDPOINT"]
289
+ params = {
290
+ "key": GOOGLE_CSE_API_KEY,
291
+ "cx": GOOGLE_CSE_CX,
292
+ "q": query,
293
+ "num": max_results
294
+ }
295
+ try:
296
+ response = requests.get(endpoint, params=params)
297
+ if response.status_code == 429:
298
+ print("Rate limit hit. Try again later.")
299
+ return []
300
+ response.raise_for_status()
301
+ data = response.json().get("items", [])
302
+ return [item.get("link") for item in data if item.get("link")]
303
+ except Exception as e:
304
+ print("Google CSE error:", e)
305
+ return []
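+ # Usage sketch (requires GOOGLE_CSE_API_KEY, GOOGLE_CSE_CX and SEARCH_ENDPOINT in the
+ # environment; the query below is just an example paper title):
+ #   links = search_google_custom("South Asian maternal and paternal lineages in southern Thailand", max_results=3)
+ #   # -> a list of up to 3 result URLs, or [] on error / rate limit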
306
+ # Step 3: Extract text: get the paper (HTML text) and supplementary materials (PDF, DOC, Excel), then do text preprocessing
307
+ # Step 3.1: Extract Text
308
+ # sub: download excel file
309
+ def download_excel_file(url, save_path="temp.xlsx"):
310
+ if "view.officeapps.live.com" in url:
311
+ parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
312
+ real_url = urllib.parse.unquote(parsed_url["src"][0])
313
+ response = requests.get(real_url)
314
+ with open(save_path, "wb") as f:
315
+ f.write(response.content)
316
+ return save_path
317
+ elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
318
+ response = requests.get(url)
319
+ response.raise_for_status() # Raises error if download fails
320
+ with open(save_path, "wb") as f:
321
+ f.write(response.content)
322
+ return save_path
323
+ else:
324
+ print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
325
+ return url
326
+ def get_paper_text(doi,id,manualLinks=None):
327
+ # create the temporary folder to contain the texts
328
+ folder_path = Path("data/"+str(id))
329
+ if not folder_path.exists():
330
+ cmd = f'mkdir data/{id}'
331
+ result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
332
+ print("data/"+str(id) +" created.")
333
+ else:
334
+ print("data/"+str(id) +" already exists.")
335
+ saveLinkFolder = "data/"+id
336
+
337
+ link = 'https://doi.org/' + doi
338
+ '''textsToExtract = { "doiLink":"paperText"
339
+ "file1.pdf":"text1",
340
+ "file2.doc":"text2",
341
+ "file3.xlsx":excelText3'''
342
+ textsToExtract = {}
343
+ # get the file to create listOfFile for each id
344
+ html = extractHTML.HTML("",link)
345
+ jsonSM = html.getSupMaterial()
346
+ text = ""
347
+ links = [link] + sum((jsonSM[key] for key in jsonSM),[])
348
+ if manualLinks != None:
349
+ links += manualLinks
350
+ for l in links:
351
+ # get the main paper
352
+ name = l.split("/")[-1]
353
+ file_path = folder_path / name
354
+ if l == link:
355
+ text = html.getListSection()
356
+ textsToExtract[link] = text
357
+ elif l.endswith(".pdf"):
358
+ if file_path.is_file():
359
+ l = saveLinkFolder + "/" + name
360
+ print("File exists.")
361
+ p = pdf.PDF(l,saveLinkFolder,doi)
362
+ f = p.openPDFFile()
363
+ pdf_path = saveLinkFolder + "/" + l.split("/")[-1]
364
+ doc = fitz.open(pdf_path)
365
+ text = "\n".join([page.get_text() for page in doc])
366
+ textsToExtract[l] = text
367
+ elif l.endswith(".doc") or l.endswith(".docx"):
368
+ d = wordDoc.wordDoc(l,saveLinkFolder)
369
+ text = d.extractTextByPage()
370
+ textsToExtract[l] = text
371
+ elif l.split(".")[-1].lower() in "xlsx":
372
+ wc = word2vec.word2Vec()
373
+ # download excel file if it not downloaded yet
374
+ savePath = saveLinkFolder +"/"+ l.split("/")[-1]
375
+ excelPath = download_excel_file(l, savePath)
376
+ corpus = wc.tableTransformToCorpusText([],excelPath)
377
+ text = ''
378
+ for c in corpus:
379
+ para = corpus[c]
380
+ for words in para:
381
+ text += " ".join(words)
382
+ textsToExtract[l] = text
383
+ # delete folder after finishing getting text
384
+ #cmd = f'rm -r data/{id}'
385
+ #result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
386
+ return textsToExtract
387
+ # Step 3.2: Extract context
388
+ def extract_context(text, keyword, window=500):
389
+ # firstly try accession number
390
+ idx = text.find(keyword)
391
+ if idx == -1:
392
+ return "Sample ID not found."
393
+ return text[max(0, idx-window): idx+window]
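+ # Example: returns a character window around the first occurrence of the keyword, e.g.
+ #   extract_context("... isolate MJ17 was sampled in northern Thailand ...", "MJ17", window=30)
+ #   # -> up to 30 characters on each side of "MJ17"; "Sample ID not found." if the keyword is absent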
394
+ def extract_relevant_paragraphs(text, accession, keep_if=None, isolate=None):
395
+ if keep_if is None:
396
+ keep_if = ["sample", "method", "mtdna", "sequence", "collected", "dataset", "supplementary", "table"]
397
+
398
+ outputs = ""
399
+ text = text.lower()
400
+
401
+ # If the accession is mentioned, prioritize the paragraph around it
402
+ # Likewise, prioritize the paragraph around the isolate name (if provided)
403
+ if accession and accession.lower() in text:
404
+ if extract_context(text, accession.lower(), window=700) != "Sample ID not found.":
405
+ outputs += extract_context(text, accession.lower(), window=700)
406
+ if isolate and isolate.lower() in text:
407
+ if extract_context(text, isolate.lower(), window=700) != "Sample ID not found.":
408
+ outputs += extract_context(text, isolate.lower(), window=700)
409
+ for keyword in keep_if:
410
+ para = extract_context(text, keyword)
411
+ if para and para not in outputs:
412
+ outputs += para + "\n"
413
+ return outputs
414
+ # Step 4: Classification for now (demo purposes)
415
+ # 4.1: Using a HuggingFace model (question-answering)
416
+ def infer_fromQAModel(context, question="Where is the mtDNA sample from?"):
417
+ try:
418
+ qa = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
419
+ result = qa({"context": context, "question": question})
420
+ return result.get("answer", "Unknown")
421
+ except Exception as e:
422
+ return f"Error: {str(e)}"
423
+
424
+ # 4.2: Infer from haplogroup
425
+ # Load pre-trained spaCy model for NER
426
+ try:
427
+ nlp = spacy.load("en_core_web_sm")
428
+ except OSError:
429
+ download("en_core_web_sm")
430
+ nlp = spacy.load("en_core_web_sm")
431
+
432
+ # Define the haplogroup-to-region mapping (simple rule-based)
433
+ import csv
434
+
435
+ def load_haplogroup_mapping(csv_path):
436
+ mapping = {}
437
+ with open(csv_path) as f:
438
+ reader = csv.DictReader(f)
439
+ for row in reader:
440
+ mapping[row["haplogroup"]] = [row["region"],row["source"]]
441
+ return mapping
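+ # The CSV is expected to have at least the columns haplogroup, region and source,
+ # e.g. a row like "B4,East Asia,SomeReference" (values here are only illustrative):
+ #   haplo_map = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
+ #   haplo_map.get("B4", ["Unknown", "Unknown"])  # -> [region, source]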
442
+
443
+ # Function to extract haplogroup from the text
444
+ def extract_haplogroup(text):
445
+ match = re.search(r'\bhaplogroup\s+([A-Z][0-9a-z]*)\b', text)
446
+ if match:
447
+ submatch = re.match(r'^[A-Z][0-9]*', match.group(1))
448
+ if submatch:
449
+ return submatch.group(0)
450
+ else:
451
+ return match.group(1) # fallback
452
+ fallback = re.search(r'\b([A-Z][0-9a-z]{1,5})\b', text)
453
+ if fallback:
454
+ return fallback.group(1)
455
+ return None
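+ # Example: the inner match trims sub-clade letters down to the major clade, e.g.
+ #   extract_haplogroup("the sample belongs to haplogroup B4a1a")  # -> "B4"
+ #   extract_haplogroup("text with no haplogroup mention")         # -> None (or a loose fallback token)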
456
+
457
+
458
+ # Function to extract location based on NER
459
+ def extract_location(text):
460
+ doc = nlp(text)
461
+ locations = []
462
+ for ent in doc.ents:
463
+ if ent.label_ == "GPE": # GPE = Geopolitical Entity (location)
464
+ locations.append(ent.text)
465
+ return locations
466
+
467
+ # Function to infer location from haplogroup
468
+ def infer_location_from_haplogroup(haplogroup):
469
+ haplo_map = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
470
+ return haplo_map.get(haplogroup, ["Unknown","Unknown"])
471
+
472
+ # Function to classify the mtDNA sample
473
+ def classify_mtDNA_sample_from_haplo(text):
474
+ # Extract haplogroup
475
+ haplogroup = extract_haplogroup(text)
476
+ # Extract location based on NER
477
+ locations = extract_location(text)
478
+ # Infer location based on haplogroup
479
+ inferred_location, sourceHaplo = infer_location_from_haplogroup(haplogroup)[0],infer_location_from_haplogroup(haplogroup)[1]
480
+ return {
481
+ "source":sourceHaplo,
482
+ "locations_found_in_context": locations,
483
+ "haplogroup": haplogroup,
484
+ "inferred_location": inferred_location
485
+
486
+ }
487
+ # 4.3 Get from available NCBI
488
+ def infer_location_fromNCBI(accession):
489
+ try:
490
+ handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
491
+ text = handle.read()
492
+ handle.close()
493
+ match = re.search(r'/(geo_loc_name|country|location)\s*=\s*"([^"]+)"', text)
494
+ if match:
495
+ return match.group(2), match.group(0) # This is the value like "Brunei"
496
+ return "Not found", "Not found"
497
+
498
+ except Exception as e:
499
+ print("❌ Entrez error:", e)
500
+ return "Not found", "Not found"
501
+
502
+ ### ANCIENT/MODERN FLAG
503
+ from Bio import Entrez
504
+ import re
505
+
506
+ def flag_ancient_modern(accession, textsToExtract, isolate=None):
507
+ """
508
+ Try to classify a sample as Ancient or Modern using:
509
+ 1. NCBI accession (if available)
510
+ 2. Supplementary text or context fallback
511
+ """
512
+ context = ""
513
+ label, explain = "", ""
514
+
515
+ try:
516
+ # Check if we can fetch metadata from NCBI using the accession
517
+ handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
518
+ text = handle.read()
519
+ handle.close()
520
+
521
+ isolate_source = re.search(r'/(isolation_source)\s*=\s*"([^"]+)"', text)
522
+ if isolate_source:
523
+ context += isolate_source.group(0) + " "
524
+
525
+ specimen = re.search(r'/(specimen|specimen_voucher)\s*=\s*"([^"]+)"', text)
526
+ if specimen:
527
+ context += specimen.group(0) + " "
528
+
529
+ if context.strip():
530
+ label, explain = detect_ancient_flag(context)
531
+ if label!="Unknown":
532
+ return label, explain + " from NCBI\n(" + context + ")"
533
+
534
+ # If no useful NCBI metadata, check supplementary texts
535
+ if textsToExtract:
536
+ labels = {"modern": [0, ""], "ancient": [0, ""], "unknown": 0}
537
+
538
+ for source in textsToExtract:
539
+ text_block = textsToExtract[source]
540
+ context = extract_relevant_paragraphs(text_block, accession, isolate=isolate) # Reduce to informative paragraph(s)
541
+ label, explain = detect_ancient_flag(context)
542
+
543
+ if label == "Ancient":
544
+ labels["ancient"][0] += 1
545
+ labels["ancient"][1] += f"{source}:\n{explain}\n\n"
546
+ elif label == "Modern":
547
+ labels["modern"][0] += 1
548
+ labels["modern"][1] += f"{source}:\n{explain}\n\n"
549
+ else:
550
+ labels["unknown"] += 1
551
+
552
+ if max(labels["modern"][0],labels["ancient"][0]) > 0:
553
+ if labels["modern"][0] > labels["ancient"][0]:
554
+ return "Modern", labels["modern"][1]
555
+ else:
556
+ return "Ancient", labels["ancient"][1]
557
+ else:
558
+ return "Unknown", "No strong keywords detected"
559
+ else:
560
+ print("No DOI or PubMed ID available for inference.")
561
+ return "", ""
562
+
563
+ except Exception as e:
564
+ print("Error:", e)
565
+ return "", ""
566
+
567
+
568
+ def detect_ancient_flag(context_snippet):
569
+ context = context_snippet.lower()
570
+
571
+ ancient_keywords = [
572
+ "ancient", "archaeological", "prehistoric", "neolithic", "mesolithic", "paleolithic",
573
+ "bronze age", "iron age", "burial", "tomb", "skeleton", "14c", "radiocarbon", "carbon dating",
574
+ "postmortem damage", "udg treatment", "adna", "degradation", "site", "excavation",
575
+ "archaeological context", "temporal transect", "population replacement", "cal bp", "calbp", "carbon dated"
576
+ ]
577
+
578
+ modern_keywords = [
579
+ "modern", "hospital", "clinical", "consent","blood","buccal","unrelated", "blood sample","buccal sample","informed consent", "donor", "healthy", "patient",
580
+ "genotyping", "screening", "medical", "cohort", "sequencing facility", "ethics approval",
581
+ "we analysed", "we analyzed", "dataset includes", "new sequences", "published data",
582
+ "control cohort", "sink population", "genbank accession", "sequenced", "pipeline",
583
+ "bioinformatic analysis", "samples from", "population genetics", "genome-wide data", "imr collection"
584
+ ]
585
+
586
+ ancient_hits = [k for k in ancient_keywords if k in context]
587
+ modern_hits = [k for k in modern_keywords if k in context]
588
+
589
+ if ancient_hits and not modern_hits:
590
+ return "Ancient", f"Flagged as ancient due to keywords: {', '.join(ancient_hits)}"
591
+ elif modern_hits and not ancient_hits:
592
+ return "Modern", f"Flagged as modern due to keywords: {', '.join(modern_hits)}"
593
+ elif ancient_hits and modern_hits:
594
+ if len(ancient_hits) >= len(modern_hits):
595
+ return "Ancient", f"Mixed context, leaning ancient due to: {', '.join(ancient_hits)}"
596
+ else:
597
+ return "Modern", f"Mixed context, leaning modern due to: {', '.join(modern_hits)}"
598
+
599
+ # Fallback to QA
600
+ answer = infer_fromQAModel(context, question="Are the mtDNA samples ancient or modern? Explain why.")
601
+ if answer.startswith("Error"):
602
+ return "Unknown", answer
603
+ if "ancient" in answer.lower():
604
+ return "Ancient", f"Leaning ancient based on QA: {answer}"
605
+ elif "modern" in answer.lower():
606
+ return "Modern", f"Leaning modern based on QA: {answer}"
607
+ else:
608
+ return "Unknown", f"No strong keywords or QA clues. QA said: {answer}"
609
+
610
+ # STEP 5: Main pipeline: accession -> 1. get pubmed id and isolate -> 2. get doi -> 3. get text -> 4. prediction -> 5. output: inferred location + explanation + confidence score
611
+ def classify_sample_location(accession):
612
+ outputs = {}
613
+ keyword, context, location, qa_result, haplo_result = "", "", "", "", ""
614
+ # Step 1: get pubmed id and isolate
615
+ pubmedID, isolate = get_info_from_accession(accession)
616
+ '''if not pubmedID:
617
+ return {"error": f"Could not retrieve PubMed ID for accession {accession}"}'''
618
+ if not isolate:
619
+ isolate = "UNKNOWN_ISOLATE"
620
+ # Step 2: get doi
621
+ doi = get_doi_from_pubmed_id(pubmedID)
622
+ '''if not doi:
623
+ return {"error": "DOI not found for this accession. Cannot fetch paper or context."}'''
624
+ # Step 3: get text
625
+ '''textsToExtract = { "doiLink":"paperText"
626
+ "file1.pdf":"text1",
627
+ "file2.doc":"text2",
628
+ "file3.xlsx":excelText3'''
629
+ if doi and pubmedID:
630
+ textsToExtract = get_paper_text(doi,pubmedID)
631
+ else: textsToExtract = {}
632
+ '''if not textsToExtract:
633
+ return {"error": f"No texts extracted for DOI {doi}"}'''
634
+ if isolate not in [None, "UNKNOWN_ISOLATE"]:
635
+ label, explain = flag_ancient_modern(accession,textsToExtract,isolate)
636
+ else:
637
+ label, explain = flag_ancient_modern(accession,textsToExtract)
638
+ # Step 4: prediction
639
+ outputs[accession] = {}
640
+ outputs[isolate] = {}
641
+ # 4.0 Infer from NCBI
642
+ location, outputNCBI = infer_location_fromNCBI(accession)
643
+ NCBI_result = {
644
+ "source": "NCBI",
645
+ "sample_id": accession,
646
+ "predicted_location": location,
647
+ "context_snippet": outputNCBI}
648
+ outputs[accession]["NCBI"]= {"NCBI": NCBI_result}
649
+ if textsToExtract:
650
+ long_text = ""
651
+ for key in textsToExtract:
652
+ text = textsToExtract[key]
653
+ # try accession number first
654
+ outputs[accession][key] = {}
655
+ keyword = accession
656
+ context = extract_context(text, keyword, window=500)
657
+ # 4.1: Using a HuggingFace model (question-answering)
658
+ location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
659
+ qa_result = {
660
+ "source": key,
661
+ "sample_id": keyword,
662
+ "predicted_location": location,
663
+ "context_snippet": context
664
+ }
665
+ outputs[keyword][key]["QAModel"] = qa_result
666
+ # 4.2: Infer from haplogroup
667
+ haplo_result = classify_mtDNA_sample_from_haplo(context)
668
+ outputs[keyword][key]["haplogroup"] = haplo_result
669
+ # try isolate
670
+ keyword = isolate
671
+ outputs[isolate][key] = {}
672
+ context = extract_context(text, keyword, window=500)
673
+ # 4.1.1: Using a HuggingFace model (question-answering)
674
+ location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
675
+ qa_result = {
676
+ "source": key,
677
+ "sample_id": keyword,
678
+ "predicted_location": location,
679
+ "context_snippet": context
680
+ }
681
+ outputs[keyword][key]["QAModel"] = qa_result
682
+ # 4.2.1: Infer from haplogroup
683
+ haplo_result = classify_mtDNA_sample_from_haplo(context)
684
+ outputs[keyword][key]["haplogroup"] = haplo_result
685
+ # add long text
686
+ long_text += text + ". \n"
687
+ # 4.3: UpgradeClassify
688
+ # try sample_id as accession number
689
+ sample_id = accession
690
+ if sample_id:
691
+ filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
692
+ locations = infer_location_for_sample(sample_id.upper(), filtered_context)
693
+ if locations!="No clear location found in top matches":
694
+ outputs[sample_id]["upgradeClassifier"] = {}
695
+ outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
696
+ "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
697
+ "sample_id": sample_id,
698
+ "predicted_location": ", ".join(locations),
699
+ "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
700
+ }
701
+ # try sample_id as isolate name
702
+ sample_id = isolate
703
+ if sample_id:
704
+ filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
705
+ locations = infer_location_for_sample(sample_id.upper(), filtered_context)
706
+ if locations!="No clear location found in top matches":
707
+ outputs[sample_id]["upgradeClassifier"] = {}
708
+ outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
709
+ "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
710
+ "sample_id": sample_id,
711
+ "predicted_location": ", ".join(locations),
712
+ "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
713
+ }
714
  return outputs, label, explain
offer.html ADDED
@@ -0,0 +1,77 @@
1
+
2
+ <div style="font-family: sans-serif; line-height: 1.6;">
3
+ <h1>mtDNA Location Classifier</h1>
4
+
5
+ <h2>Publish faster. Analyze smarter.</h2>
6
+ <p>Are you dealing with <strong>incomplete mtDNA metadata</strong> (like country, ethnicity, sample type)?<br>
7
+ This tool helps researchers like you generate <strong>clean, structured labels</strong> — ready to use for your paper.</p>
8
+
9
+ <hr>
10
+
11
+ <h2> What You’ll Get:</h2>
12
+ <ul>
13
+ <li>Inference from sequence ID alone</li>
14
+ <li>Handles hard edge cases</li>
15
+ <li>Clear sample type, country, and more (ethnicity, phenotype, etc.)</li>
16
+ <li>Excel export with citations</li>
17
+ <li>Feedback-based refund policy</li>
18
+ </ul>
19
+
20
+ <hr>
21
+
22
+ <h2>Free Tier</h2>
23
+ <ul>
24
+ <li>30 free samples — no email needed</li>
25
+ <li>+20 bonus samples + Excel file when you enter your email</li>
26
+ <li>Don’t like the result? Tell us why on the report — we won’t count the bad ones (email required)</li>
27
+ </ul>
28
+
29
+ <hr>
30
+
31
+ <h2>Pricing — Pay As You Go (DIY)</h2>
32
+ <table border="1" cellpadding="6" cellspacing="0">
33
+ <thead>
34
+ <tr>
35
+ <th>Case Type</th>
36
+ <th>Price/Sample</th>
37
+ <th>Output</th>
38
+ </tr>
39
+ </thead>
40
+ <tbody>
41
+ <tr><td>Normal</td><td>$0.10</td><td>Sample Type + Country</td></tr>
42
+ <tr><td>Edge</td><td>$1.00</td><td>Sample Type + Country</td></tr>
43
+ <tr><td>Niche</td><td>$2.00</td><td>Sample Type + Country + 1 Custom Label</td></tr>
44
+ </tbody>
45
+ </table>
46
+
47
+ <hr>
48
+
49
+ <h2>Batch Discount (1000+ Samples)</h2>
50
+ <ul>
51
+ <li><strong>Normal Output</strong> → $100 total ($0.10/sample)<br>Unsatisfied samples? We’ll refund them.</li>
52
+ <li><strong>Niche Output</strong> → $500 total ($0.50/sample)<br>Includes an extra label like ethnicity or phenotype.</li>
53
+ </ul>
54
+
55
+ <hr>
56
+
57
+ <h2>Early User Bonus (Limited!)</h2>
58
+ <p>Are you one of our <strong>first 10 paying users</strong>?<br>
59
+ Just type <code>early_user</code> in your email.</p>
60
+ <p>You'll get <strong>20% lifetime discount</strong> on every plan — forever.<br>
61
+ We’ll apply this automatically so you don’t have to calculate anything.</p>
62
+
63
+ <hr>
64
+
65
+ <h2>Our Mission</h2>
66
+ <p>Give clean, high-quality, open-access biological datasets that save researchers time and improve scientific reproducibility.<br>
67
+ Build the world's clean, AI-driven open metadata source for biological research.</p>
68
+
69
+ <hr>
70
+
71
+ <h2>Try It Now</h2>
72
+ <p>Paste your sequence ID on our demo:<br>
73
+ <a href="https://huggingface.co/spaces/VyLala/mtDNALocation" target="_blank">Try the Classifier</a></p>
74
+ <p>Need help or bulk analysis?<br>
75
+ <a href="mailto:[email protected]" target="_blank">Contact Us</a></p>
76
+
77
+ </div>
pipeline.py CHANGED
@@ -1,649 +1,649 @@
1
- # test1: MJ17 direct
2
- # test2: "A1YU101" thailand cross-ref
3
- # test3: "EBK109" thailand cross-ref
4
- # test4: "OQ731952"/"BST115" for search query title: "South Asian maternal and paternal lineages in southern Thailand and"
5
- import data_preprocess
6
- import model
7
- import mtdna_classifier
8
- #import app
9
- import smart_fallback
10
- import pandas as pd
11
- from pathlib import Path
12
- import subprocess
13
- from NER.html import extractHTML
14
- import os
15
- import google.generativeai as genai
16
- import re
17
- import standardize_location
18
- # Helper functions in for this pipeline
19
- # Track time
20
- import time
21
- import multiprocessing
22
- import gspread
23
- from googleapiclient.discovery import build
24
- from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload
25
- from google.oauth2.service_account import Credentials
26
- from oauth2client.service_account import ServiceAccountCredentials
27
- import io
28
- import json
29
- #––– Authentication setup –––
30
- GDRIVE_PARENT_FOLDER_NAME = "mtDNA-Location-Classifier"
31
- GDRIVE_DATA_FOLDER_NAME = os.environ["GDRIVE_DATA_FOLDER_NAME"]
32
- GCP_CREDS_DICT = json.loads(os.environ["GCP_CREDS_JSON"]) # from HF secrets
33
- GDRIVE_CREDS = Credentials.from_service_account_info(GCP_CREDS_DICT, scopes=["https://www.googleapis.com/auth/drive"])
34
- drive_service = build("drive", "v3", credentials=GDRIVE_CREDS)
35
-
36
- def get_or_create_drive_folder(name, parent_id=None):
37
- query = f"name='{name}' and mimeType='application/vnd.google-apps.folder'"
38
- if parent_id:
39
- query += f" and '{parent_id}' in parents"
40
- results = drive_service.files().list(q=query, spaces='drive', fields="files(id, name)").execute()
41
- items = results.get("files", [])
42
- if items:
43
- return items[0]["id"]
44
- file_metadata = {
45
- "name": name,
46
- "mimeType": "application/vnd.google-apps.folder"
47
- }
48
- if parent_id:
49
- file_metadata["parents"] = [parent_id]
50
- file = drive_service.files().create(body=file_metadata, fields="id").execute()
51
- return file["id"]
52
- def find_drive_file(filename, parent_id):
53
- """
54
- Checks if a file with the given name exists inside the specified Google Drive folder.
55
- Returns the file ID if found, else None.
56
- """
57
- query = f"'{parent_id}' in parents and name = '{filename}' and trashed = false"
58
- results = drive_service.files().list(q=query, spaces='drive', fields='files(id, name)', pageSize=1).execute()
59
- files = results.get('files', [])
60
- if files:
61
- return files[0]["id"]
62
- return None
63
-
64
-
65
- # def upload_file_to_drive(local_path, remote_name, folder_id):
66
- # file_metadata = {"name": remote_name, "parents": [folder_id]}
67
- # media = MediaFileUpload(local_path, resumable=True)
68
- # existing = drive_service.files().list(q=f"name='{remote_name}' and '{folder_id}' in parents", fields="files(id)").execute().get("files", [])
69
- # if existing:
70
- # drive_service.files().delete(fileId=existing[0]["id"]).execute()
71
- # file = drive_service.files().create(body=file_metadata, media_body=media, fields="id").execute()
72
- # result = drive_service.files().list(q=f"name='{remote_name}' and '{folder_id}' in parents", fields="files(id)").execute()
73
- # if not result.get("files"):
74
- # print(f"❌ Upload failed: File '{remote_name}' not found in folder after upload.")
75
- # else:
76
- # print(f"✅ Verified upload: {remote_name}")
77
- # return file["id"]
78
- def upload_file_to_drive(local_path, remote_name, folder_id):
79
- try:
80
- if not os.path.exists(local_path):
81
- raise FileNotFoundError(f"❌ Local file does not exist: {local_path}")
82
-
83
- # Delete existing file on Drive if present
84
- existing = drive_service.files().list(
85
- q=f"name='{remote_name}' and '{folder_id}' in parents and trashed = false",
86
- fields="files(id)"
87
- ).execute().get("files", [])
88
-
89
- if existing:
90
- drive_service.files().delete(fileId=existing[0]["id"]).execute()
91
- print(f"🗑️ Deleted existing '{remote_name}' in Drive folder {folder_id}")
92
-
93
- file_metadata = {"name": remote_name, "parents": [folder_id]}
94
- media = MediaFileUpload(local_path, resumable=True)
95
- file = drive_service.files().create(
96
- body=file_metadata,
97
- media_body=media,
98
- fields="id"
99
- ).execute()
100
-
101
- print(f"✅ Uploaded '{remote_name}' to Google Drive folder ID: {folder_id}")
102
- return file["id"]
103
-
104
- except Exception as e:
105
- print(f"❌ Error during upload: {e}")
106
- return None
107
-
108
-
109
- def download_file_from_drive(remote_name, folder_id, local_path):
110
- results = drive_service.files().list(q=f"name='{remote_name}' and '{folder_id}' in parents", fields="files(id)").execute()
111
- files = results.get("files", [])
112
- if not files:
113
- return False
114
- file_id = files[0]["id"]
115
- request = drive_service.files().get_media(fileId=file_id)
116
- fh = io.FileIO(local_path, 'wb')
117
- downloader = MediaIoBaseDownload(fh, request)
118
- done = False
119
- while not done:
120
- _, done = downloader.next_chunk()
121
- return True
122
- def download_drive_file_content(file_id):
123
- request = drive_service.files().get_media(fileId=file_id)
124
- fh = io.BytesIO()
125
- downloader = MediaIoBaseDownload(fh, request)
126
- done = False
127
- while not done:
128
- _, done = downloader.next_chunk()
129
- fh.seek(0)
130
- return fh.read().decode("utf-8")
131
-
132
- # def run_with_timeout(func, args=(), kwargs={}, timeout=20):
133
- # """
134
- # Runs `func` with timeout in seconds. Kills if it exceeds.
135
- # Returns: (success, result or None)
136
- # """
137
- # def wrapper(q, *args, **kwargs):
138
- # try:
139
- # q.put(func(*args, **kwargs))
140
- # except Exception as e:
141
- # q.put(e)
142
-
143
- # q = multiprocessing.Queue()
144
- # p = multiprocessing.Process(target=wrapper, args=(q, *args), kwargs=kwargs)
145
- # p.start()
146
- # p.join(timeout)
147
-
148
- # if p.is_alive():
149
- # p.terminate()
150
- # p.join()
151
- # print(f"⏱️ Timeout exceeded ({timeout} sec) — function killed.")
152
- # return False, None
153
- # else:
154
- # result = q.get()
155
- # if isinstance(result, Exception):
156
- # raise result
157
- # return True, result
158
- def run_with_timeout(func, args=(), kwargs={}, timeout=30):
159
- import concurrent.futures
160
- with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
161
- future = executor.submit(func, *args, **kwargs)
162
- try:
163
- return True, future.result(timeout=timeout)
164
- except concurrent.futures.TimeoutError:
165
- print(f"⏱️ Timeout exceeded ({timeout} sec) — function killed.")
166
- return False, None
167
-
168
- def time_it(func, *args, **kwargs):
169
- """
170
- Measure how long a function takes to run and return its result + time.
171
- """
172
- start = time.time()
173
- result = func(*args, **kwargs)
174
- end = time.time()
175
- elapsed = end - start
176
- print(f"⏱️ '{func.__name__}' took {elapsed:.3f} seconds")
177
- return result, elapsed
178
- # --- Define Pricing Constants (for Gemini 1.5 Flash & text-embedding-004) ---
179
- def track_gemini_cost():
180
- # Prices are per 1,000 tokens
181
- PRICE_PER_1K_INPUT_LLM = 0.000075 # $0.075 per 1M tokens
182
- PRICE_PER_1K_OUTPUT_LLM = 0.0003 # $0.30 per 1M tokens
183
- PRICE_PER_1K_EMBEDDING_INPUT = 0.000025 # $0.025 per 1M tokens
184
- return True
185
-
186
- def unique_preserve_order(seq):
187
- seen = set()
188
- return [x for x in seq if not (x in seen or seen.add(x))]
189
- # Main execution
190
- def pipeline_with_gemini(accessions):
191
- # output: country, sample_type, ethnic, location, money_cost, time_cost, explain
192
- # there can be one accession number in the accessions
193
- # Prices are per 1,000 tokens
194
- PRICE_PER_1K_INPUT_LLM = 0.000075 # $0.075 per 1M tokens
195
- PRICE_PER_1K_OUTPUT_LLM = 0.0003 # $0.30 per 1M tokens
196
- PRICE_PER_1K_EMBEDDING_INPUT = 0.000025 # $0.025 per 1M tokens
197
- if not accessions:
198
- print("no input")
199
- return None
200
- else:
201
- accs_output = {}
202
- genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
203
- for acc in accessions:
204
- print("start gemini: ", acc)
205
- start = time.time()
206
- total_cost_title = 0
207
- jsonSM, links, article_text = {},[], ""
208
- acc_score = { "isolate": "",
209
- "country":{},
210
- "sample_type":{},
211
- #"specific_location":{},
212
- #"ethnicity":{},
213
- "query_cost":total_cost_title,
214
- "time_cost":None,
215
- "source":links}
216
- meta = mtdna_classifier.fetch_ncbi_metadata(acc)
217
- country, spe_loc, ethnic, sample_type, col_date, iso, title, doi, pudID, features = meta["country"], meta["specific_location"], meta["ethnicity"], meta["sample_type"], meta["collection_date"], meta["isolate"], meta["title"], meta["doi"], meta["pubmed_id"], meta["all_features"]
218
- acc_score["isolate"] = iso
219
- print(meta)
220
- meta_expand = smart_fallback.fetch_ncbi(acc)
221
- # set up step: create the folder to save document
222
- chunk, all_output = "",""
223
- if pudID:
224
- id = str(pudID)
225
- saveTitle = title
226
- else:
227
- try:
228
- author_name = meta_expand["authors"].split(',')[0] # Use last name only
229
- except:
230
- author_name = meta_expand["authors"]
231
- saveTitle = title + "_" + col_date + "_" + author_name
232
- id = "DirectSubmission"
233
- # folder_path = Path("/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/data/"+str(id))
234
- # if not folder_path.exists():
235
- # cmd = f'mkdir /content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/data/{id}'
236
- # result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
237
- # print("data/"+str(id) +" created.")
238
- # else:
239
- # print("data/"+str(id) +" already exists.")
240
- # saveLinkFolder = "/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/data/"+str(id)
241
- # parent_folder_id = get_or_create_drive_folder(GDRIVE_PARENT_FOLDER_NAME)
242
- # data_folder_id = get_or_create_drive_folder(GDRIVE_DATA_FOLDER_NAME, parent_id=parent_folder_id)
243
- # sample_folder_id = get_or_create_drive_folder(str(id), parent_id=data_folder_id)
244
- data_folder_id = GDRIVE_DATA_FOLDER_NAME # Use the shared folder directly
245
- sample_folder_id = get_or_create_drive_folder(str(id), parent_id=data_folder_id)
246
- print("sample folder id: ", sample_folder_id)
247
-
248
- # Define document names
249
- if len(saveTitle) > 50:
250
- saveName = saveTitle[:50]
251
- saveName = saveName.replace(" ", "_")
252
- chunk_filename = f"{saveName}_merged_document.docx"
253
- all_filename = f"{saveName}_all_merged_document.docx"
254
- else:
255
- saveName = saveTitle.replace(" ", "_")
256
- chunk_filename = f"{saveName}_merged_document.docx"
257
- all_filename = f"{saveName}_all_merged_document.docx"
258
- print(chunk_filename, all_filename)
259
- # Define local temp paths for reading/writing
260
- # import tempfile
261
- # tmp_dir = tempfile.mkdtemp()
262
- LOCAL_TEMP_DIR = "/mnt/data/generated_docs"
263
- os.makedirs(LOCAL_TEMP_DIR, exist_ok=True)
264
- file_chunk_path = os.path.join(LOCAL_TEMP_DIR, chunk_filename)
265
- file_all_path = os.path.join(LOCAL_TEMP_DIR, all_filename)
266
- # file_chunk_path = os.path.join(tempfile.gettempdir(), chunk_filename)
267
- # file_all_path = os.path.join(tempfile.gettempdir(), all_filename)
268
- print(file_chunk_path)
269
- chunk_id = find_drive_file(chunk_filename, sample_folder_id)
270
- all_id = find_drive_file(all_filename, sample_folder_id)
271
-
272
- if chunk_id and all_id:
273
- print("✅ Files already exist in Google Drive. Downloading them...")
274
- chunk_exists = download_file_from_drive(chunk_filename, sample_folder_id, file_chunk_path)
275
- all_exists = download_file_from_drive(all_filename, sample_folder_id, file_all_path)
276
-
277
- # Read and parse these into `chunk` and `all_output`
278
- else:
279
- # 🔥 Remove any stale local copies
280
- if os.path.exists(file_chunk_path):
281
- os.remove(file_chunk_path)
282
- print(f"🗑️ Removed stale: {file_chunk_path}")
283
- if os.path.exists(file_all_path):
284
- os.remove(file_all_path)
285
- print(f"🗑️ Removed stale: {file_all_path}")
286
- # 🔥 Remove the local file first if it exists
287
- # if os.path.exists(file_chunk_path):
288
- # os.remove(file_chunk_path)
289
- # print("remove chunk path")
290
- # if os.path.exists(file_all_path):
291
- # os.remove(file_all_path)
292
- # print("remove all path")
293
- # Try to download if already exists on Drive
294
- chunk_exists = download_file_from_drive(chunk_filename, sample_folder_id, file_chunk_path)
295
- all_exists = download_file_from_drive(all_filename, sample_folder_id, file_all_path)
296
- print("chunk exist: ", chunk_exists)
297
- # first way: ncbi method
298
- print("country.lower: ",country.lower())
299
- if country.lower() != "unknown":
300
- stand_country = standardize_location.smart_country_lookup(country.lower())
301
- print("stand_country: ", stand_country)
302
- if stand_country.lower() != "not found":
303
- acc_score["country"][stand_country.lower()] = ["ncbi"]
304
- else: acc_score["country"][country.lower()] = ["ncbi"]
305
- # if spe_loc.lower() != "unknown":
306
- # acc_score["specific_location"][spe_loc.lower()] = ["ncbi"]
307
- # if ethnic.lower() != "unknown":
308
- # acc_score["ethnicity"][ethnic.lower()] = ["ncbi"]
309
- if sample_type.lower() != "unknown":
310
- acc_score["sample_type"][sample_type.lower()] = ["ncbi"]
311
- # second way: LLM model
312
- # Preprocess the input token
313
- print(acc_score)
314
- accession, isolate = None, None
315
- if acc != "unknown": accession = acc
316
- if iso != "unknown": isolate = iso
317
- # check doi first
318
- if doi != "unknown":
319
- link = 'https://doi.org/' + doi
320
- # get the file to create listOfFile for each id
321
- print("link of doi: ", link)
322
- html = extractHTML.HTML("",link)
323
- jsonSM = html.getSupMaterial()
324
- article_text = html.getListSection()
325
- if article_text:
326
- if "Just a moment...Enable JavaScript and cookies to continue".lower() not in article_text.lower() or "403 Forbidden Request".lower() not in article_text.lower():
327
- links.append(link)
328
- if jsonSM:
329
- links += sum((jsonSM[key] for key in jsonSM),[])
330
- # no doi then google custom search api
331
- if len(article_text) == 0 or "Just a moment...Enable JavaScript and cookies to continue".lower() in article_text.lower() or "403 Forbidden Request".lower() in article_text.lower():
332
- # might find the article
333
- print("no article text")
334
- #tem_links = mtdna_classifier.search_google_custom(title, 2)
335
- tem_links = smart_fallback.smart_google_search(meta_expand)
336
- # get supplementary of that article
337
- print("tem links length ", len(tem_links))
338
- for link in tem_links:
339
- print("link in tem: ", link)
340
- html = extractHTML.HTML("",link)
341
- print("html yeh")
342
- jsonSM = html.getSupMaterial()
343
- print("jsonsm yeah")
344
- article_text_tem = html.getListSection()
345
- print(article_text_tem, jsonSM)
346
- if article_text_tem:
347
- if "Just a moment...Enable JavaScript and cookies to continue".lower() not in article_text_tem.lower() or "403 Forbidden Request".lower() not in article_text_tem.lower():
348
- links.append(link)
349
- print("tem link appended ", link)
350
- if jsonSM:
351
- links += sum((jsonSM[key] for key in jsonSM),[])
352
- print("this is links: ",links)
353
- links = unique_preserve_order(links)
354
- acc_score["source"] = links
355
- # chunk_path = "/"+saveTitle+"_merged_document.docx"
356
- # all_path = "/"+saveTitle+"_all_merged_document.docx"
357
- # # if chunk and all output not exist yet
358
- # file_chunk_path = saveLinkFolder + chunk_path
359
- # file_all_path = saveLinkFolder + all_path
360
- # if os.path.exists(file_chunk_path):
361
- # print("File chunk exists!")
362
- # if not chunk:
363
- # text, table, document_title = model.read_docx_text(file_chunk_path)
364
- # chunk = data_preprocess.normalize_for_overlap(text) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table))
365
- # if os.path.exists(file_all_path):
366
- # print("File all output exists!")
367
- # if not all_output:
368
- # text_all, table_all, document_title_all = model.read_docx_text(file_all_path)
369
- # all_output = data_preprocess.normalize_for_overlap(text_all) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table_all))
370
- if chunk_exists:
371
- print("File chunk exists!")
372
- if not chunk:
373
- print("start to get chunk")
374
- text, table, document_title = model.read_docx_text(file_chunk_path)
375
- chunk = data_preprocess.normalize_for_overlap(text) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table))
376
- if all_exists:
377
- print("File all output exists!")
378
- if not all_output:
379
- text_all, table_all, document_title_all = model.read_docx_text(file_all_path)
380
- all_output = data_preprocess.normalize_for_overlap(text_all) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table_all))
381
- if not chunk and not all_output:
382
- print("not chunk and all output")
383
- # else: check if we can reuse these chunk and all output of existed accession to find another
384
- if links:
385
- for link in links:
386
- print(link)
387
- # if len(all_output) > 1000*1000:
388
- # all_output = data_preprocess.normalize_for_overlap(all_output)
389
- # print("after normalizing all output: ", len(all_output))
390
- if len(data_preprocess.normalize_for_overlap(all_output)) > 600000:
391
- print("break here")
392
- break
393
- if iso != "unknown": query_kw = iso
394
- else: query_kw = acc
395
- #text_link, tables_link, final_input_link = data_preprocess.preprocess_document(link,saveLinkFolder, isolate=query_kw)
396
- success_process, output_process = run_with_timeout(data_preprocess.preprocess_document,args=(link,sample_folder_id),kwargs={"isolate":query_kw},timeout=180)
397
- if success_process:
398
- text_link, tables_link, final_input_link = output_process[0], output_process[1], output_process[2]
399
- print("yes succeed for process document")
400
- else: text_link, tables_link, final_input_link = "", "", ""
401
- context = data_preprocess.extract_context(final_input_link, query_kw)
402
- if context != "Sample ID not found.":
403
- if len(data_preprocess.normalize_for_overlap(chunk)) < 1000*1000:
404
- success_chunk, the_output_chunk = run_with_timeout(data_preprocess.merge_texts_skipping_overlap,args=(chunk, context))
405
- if success_chunk:
406
- chunk = the_output_chunk#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
407
- print("yes succeed for chunk")
408
- else:
409
- chunk += context
410
- print("len context: ", len(context))
411
- print("basic fall back")
412
- print("len chunk after: ", len(chunk))
413
- if len(final_input_link) > 1000*1000:
414
- if context != "Sample ID not found.":
415
- final_input_link = context
416
- else:
417
- final_input_link = data_preprocess.normalize_for_overlap(final_input_link)
418
- if len(final_input_link) > 1000 *1000:
419
- final_input_link = final_input_link[:100000]
420
- if len(data_preprocess.normalize_for_overlap(all_output)) < 1000*1000:
421
- success, the_output = run_with_timeout(data_preprocess.merge_texts_skipping_overlap,args=(all_output, final_input_link))
422
- if success:
423
- all_output = the_output#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
424
- print("yes succeed")
425
- else:
426
- all_output += final_input_link
427
- print("len final input: ", len(final_input_link))
428
- print("basic fall back")
429
- print("len all output after: ", len(all_output))
430
- #country_pro, chunk, all_output = data_preprocess.process_inputToken(links, saveLinkFolder, accession=accession, isolate=isolate)
431
-
432
- else:
433
- chunk = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
434
- all_output = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
435
- if not chunk: chunk = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
436
- if not all_output: all_output = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
437
- if len(all_output) > 1*1024*1024:
438
- all_output = data_preprocess.normalize_for_overlap(all_output)
439
- if len(all_output) > 1*1024*1024:
440
- all_output = all_output[:1*1024*1024]
441
- print("chunk len: ", len(chunk))
442
- print("all output len: ", len(all_output))
443
- data_preprocess.save_text_to_docx(chunk, file_chunk_path)
444
- data_preprocess.save_text_to_docx(all_output, file_all_path)
445
- # Later when saving new files
446
- # data_preprocess.save_text_to_docx(chunk, chunk_filename, sample_folder_id)
447
- # data_preprocess.save_text_to_docx(all_output, all_filename, sample_folder_id)
448
-
449
- # Upload to Drive
450
- result_chunk_upload = upload_file_to_drive(file_chunk_path, chunk_filename, sample_folder_id)
451
- result_all_upload = upload_file_to_drive(file_all_path, all_filename, sample_folder_id)
452
- print("UPLOAD RESULT FOR CHUNK: ", result_chunk_upload)
453
- print(f"🔗 Uploaded file: https://drive.google.com/file/d/{result_chunk_upload}/view")
454
- print("here 1")
455
-
456
- # else:
457
- # final_input = ""
458
- # if all_output:
459
- # final_input = all_output
460
- # else:
461
- # if chunk: final_input = chunk
462
- # #data_preprocess.merge_texts_skipping_overlap(final_input, all_output)
463
- # if final_input:
464
- # keywords = []
465
- # if iso != "unknown": keywords.append(iso)
466
- # if acc != "unknown": keywords.append(acc)
467
- # for keyword in keywords:
468
- # chunkBFS = data_preprocess.get_contextual_sentences_BFS(final_input, keyword)
469
- # countryDFS, chunkDFS = data_preprocess.get_contextual_sentences_DFS(final_input, keyword)
470
- # chunk = data_preprocess.merge_texts_skipping_overlap(chunk, chunkDFS)
471
- # chunk = data_preprocess.merge_texts_skipping_overlap(chunk, chunkBFS)
472
-
473
- # Define paths for cached RAG assets
474
- # faiss_index_path = saveLinkFolder+"/faiss_index.bin"
475
- # document_chunks_path = saveLinkFolder+"/document_chunks.json"
476
- # structured_lookup_path = saveLinkFolder+"/structured_lookup.json"
477
- print("here 2")
478
- faiss_filename = "faiss_index.bin"
479
- chunks_filename = "document_chunks.json"
480
- lookup_filename = "structured_lookup.json"
481
- print("name of faiss: ", faiss_filename)
482
-
483
- faiss_index_path = os.path.join(LOCAL_TEMP_DIR, faiss_filename)
484
- document_chunks_path = os.path.join(LOCAL_TEMP_DIR, chunks_filename)
485
- structured_lookup_path = os.path.join(LOCAL_TEMP_DIR, lookup_filename)
486
- print("name if faiss path: ", faiss_index_path)
487
- # 🔥 Remove the local file first if it exists
488
- faiss_id = find_drive_file(faiss_filename, sample_folder_id)
489
- document_id = find_drive_file(chunks_filename, sample_folder_id)
490
- structure_id = find_drive_file(lookup_filename, sample_folder_id)
491
- if faiss_id and document_id and structure_id:
492
- print("✅ 3 Files already exist in Google Drive. Downloading them...")
493
- download_file_from_drive(faiss_filename, sample_folder_id, faiss_index_path)
494
- download_file_from_drive(chunks_filename, sample_folder_id, document_chunks_path)
495
- download_file_from_drive(lookup_filename, sample_folder_id, structured_lookup_path)
496
- # Read and parse these into `chunk` and `all_output`
497
- else:
498
- if os.path.exists(faiss_index_path):
499
- os.remove(faiss_index_path)
500
- if os.path.exists(document_chunks_path):
501
- os.remove(document_chunks_path)
502
- if os.path.exists(structured_lookup_path):
503
- os.remove(structured_lookup_path)
504
- download_file_from_drive(faiss_filename, sample_folder_id, faiss_index_path)
505
- download_file_from_drive(chunks_filename, sample_folder_id, document_chunks_path)
506
- download_file_from_drive(lookup_filename, sample_folder_id, structured_lookup_path)
507
-
508
- print("move to load rag")
509
- master_structured_lookup, faiss_index, document_chunks = model.load_rag_assets(
510
- faiss_index_path, document_chunks_path, structured_lookup_path
511
- )
512
-
513
- global_llm_model_for_counting_tokens = genai.GenerativeModel('gemini-1.5-flash-latest')
514
- if not all_output:
515
- if chunk: all_output = chunk
516
- else: all_output = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
517
- if faiss_index is None:
518
- print("\nBuilding RAG assets (structured lookup, FAISS index, chunks)...")
519
- total_doc_embedding_tokens = global_llm_model_for_counting_tokens.count_tokens(
520
- all_output
521
- ).total_tokens
522
-
523
- initial_embedding_cost = (total_doc_embedding_tokens / 1000) * PRICE_PER_1K_EMBEDDING_INPUT
524
- total_cost_title += initial_embedding_cost
525
- print(f"Initial one-time embedding cost for '{file_all_path}' ({total_doc_embedding_tokens} tokens): ${initial_embedding_cost:.6f}")
526
-
527
-
528
- master_structured_lookup, faiss_index, document_chunks, plain_text_content = model.build_vector_index_and_data(
529
- file_all_path, faiss_index_path, document_chunks_path, structured_lookup_path
530
- )
531
- else:
532
- print("\nRAG assets loaded from file. No re-embedding of entire document will occur.")
533
- plain_text_content_all, table_strings_all, document_title_all = model.read_docx_text(file_all_path)
534
- master_structured_lookup['document_title'] = master_structured_lookup.get('document_title', document_title_all)
535
-
536
- primary_word = iso
537
- alternative_word = acc
538
- print(f"\n--- General Query: Primary='{primary_word}' (Alternative='{alternative_word}') ---")
539
- if features.lower() not in all_output.lower():
540
- all_output += ". NCBI Features: " + features
541
- # country, sample_type, method_used, ethnic, spe_loc, total_query_cost = model.query_document_info(
542
- # primary_word, alternative_word, meta, master_structured_lookup, faiss_index, document_chunks,
543
- # model.call_llm_api, chunk=chunk, all_output=all_output)
544
- print("this is chunk for the model")
545
- print(chunk)
546
- print("this is all output for the model")
547
- print(all_output)
548
- country, sample_type, method_used, country_explanation, sample_type_explanation, total_query_cost = model.query_document_info(
549
- primary_word, alternative_word, meta, master_structured_lookup, faiss_index, document_chunks,
550
- model.call_llm_api, chunk=chunk, all_output=all_output)
551
- print("country using ai: ", country)
552
- print("sample type using ai: ", sample_type)
553
- if len(country) == 0: country = "unknown"
554
- if len(sample_type) == 0: sample_type = "unknown"
555
- if country_explanation: country_explanation = "-"+country_explanation
556
- else: country_explanation = ""
557
- if sample_type_explanation: sample_type_explanation = "-"+sample_type_explanation
558
- else: sample_type_explanation = ""
559
- if method_used == "unknown": method_used = ""
560
- if country.lower() != "unknown":
561
- stand_country = standardize_location.smart_country_lookup(country.lower())
562
- if stand_country.lower() != "not found":
563
- if stand_country.lower() in acc_score["country"]:
564
- if country_explanation:
565
- acc_score["country"][stand_country.lower()].append(method_used + country_explanation)
566
- else:
567
- acc_score["country"][stand_country.lower()] = [method_used + country_explanation]
568
- else:
569
- if country.lower() in acc_score["country"]:
570
- if country_explanation:
571
- if len(method_used + country_explanation) > 0:
572
- acc_score["country"][country.lower()].append(method_used + country_explanation)
573
- else:
574
- if len(method_used + country_explanation) > 0:
575
- acc_score["country"][country.lower()] = [method_used + country_explanation]
576
- # if spe_loc.lower() != "unknown":
577
- # if spe_loc.lower() in acc_score["specific_location"]:
578
- # acc_score["specific_location"][spe_loc.lower()].append(method_used)
579
- # else:
580
- # acc_score["specific_location"][spe_loc.lower()] = [method_used]
581
- # if ethnic.lower() != "unknown":
582
- # if ethnic.lower() in acc_score["ethnicity"]:
583
- # acc_score["ethnicity"][ethnic.lower()].append(method_used)
584
- # else:
585
- # acc_score["ethnicity"][ethnic.lower()] = [method_used]
586
- if sample_type.lower() != "unknown":
587
- if sample_type.lower() in acc_score["sample_type"]:
588
- if len(method_used + sample_type_explanation) > 0:
589
- acc_score["sample_type"][sample_type.lower()].append(method_used + sample_type_explanation)
590
- else:
591
- if len(method_used + sample_type_explanation)> 0:
592
- acc_score["sample_type"][sample_type.lower()] = [method_used + sample_type_explanation]
593
- # last resort: combine all information to give all output otherwise unknown
594
- if len(acc_score["country"]) == 0 or len(acc_score["sample_type"]) == 0:
595
- text = ""
596
- for key in meta_expand:
597
- text += str(key) + ": " + meta_expand[key] + "\n"
598
- if len(data_preprocess.normalize_for_overlap(all_output)) > 0:
599
- text += data_preprocess.normalize_for_overlap(all_output)
600
- if len(data_preprocess.normalize_for_overlap(chunk)) > 0:
601
- text += data_preprocess.normalize_for_overlap(chunk)
602
- text += ". NCBI Features: " + features
603
- print("this is text for the last resort model")
604
- print(text)
605
- country, sample_type, method_used, country_explanation, sample_type_explanation, total_query_cost = model.query_document_info(
606
- primary_word, alternative_word, meta, master_structured_lookup, faiss_index, document_chunks,
607
- model.call_llm_api, chunk=text, all_output=text)
608
- print("this is last resort results: ")
609
- print("country: ", country)
610
- print("sample type: ", sample_type)
611
- if len(country) == 0: country = "unknown"
612
- if len(sample_type) == 0: sample_type = "unknown"
613
- if country_explanation: country_explanation = "-"+country_explanation
614
- else: country_explanation = ""
615
- if sample_type_explanation: sample_type_explanation = "-"+sample_type_explanation
616
- else: sample_type_explanation = ""
617
- if method_used == "unknown": method_used = ""
618
- if country.lower() != "unknown":
619
- stand_country = standardize_location.smart_country_lookup(country.lower())
620
- if stand_country.lower() != "not found":
621
- if stand_country.lower() in acc_score["country"]:
622
- if country_explanation:
623
- acc_score["country"][stand_country.lower()].append(method_used + country_explanation)
624
- else:
625
- acc_score["country"][stand_country.lower()] = [method_used + country_explanation]
626
- else:
627
- if country.lower() in acc_score["country"]:
628
- if country_explanation:
629
- if len(method_used + country_explanation) > 0:
630
- acc_score["country"][country.lower()].append(method_used + country_explanation)
631
- else:
632
- if len(method_used + country_explanation) > 0:
633
- acc_score["country"][country.lower()] = [method_used + country_explanation]
634
- if sample_type.lower() != "unknown":
635
- if sample_type.lower() in acc_score["sample_type"]:
636
- if len(method_used + sample_type_explanation) > 0:
637
- acc_score["sample_type"][sample_type.lower()].append(method_used + sample_type_explanation)
638
- else:
639
- if len(method_used + sample_type_explanation)> 0:
640
- acc_score["sample_type"][sample_type.lower()] = [method_used + sample_type_explanation]
641
- end = time.time()
642
- total_cost_title += total_query_cost
643
- acc_score["query_cost"] = f"{total_cost_title:.6f}"
644
- elapsed = end - start
645
- acc_score["time_cost"] = f"{elapsed:.3f} seconds"
646
- accs_output[acc] = acc_score
647
- print(accs_output[acc])
648
-
649
  return accs_output
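For context on the pricing constants declared in this pipeline, here is a minimal sketch of how the per-1,000-token prices translate into the query_cost figure the pipeline reports. The token counts and the helper name estimate_cost are made-up placeholders, not values taken from the repository.

# Minimal sketch: how the per-1K-token prices yield the reported cost.
# Token counts below are placeholders, not real measurements.
PRICE_PER_1K_INPUT_LLM = 0.000075        # $0.075 per 1M tokens
PRICE_PER_1K_OUTPUT_LLM = 0.0003         # $0.30 per 1M tokens
PRICE_PER_1K_EMBEDDING_INPUT = 0.000025  # $0.025 per 1M tokens

def estimate_cost(input_tokens, output_tokens, embedding_tokens):
    # Same formula as the pipeline: tokens / 1000 * price-per-1K-tokens.
    return (
        (input_tokens / 1000) * PRICE_PER_1K_INPUT_LLM
        + (output_tokens / 1000) * PRICE_PER_1K_OUTPUT_LLM
        + (embedding_tokens / 1000) * PRICE_PER_1K_EMBEDDING_INPUT
    )

# Example: 120k input, 2k output, 500k embedded tokens -> about $0.022.
print(f"{estimate_cost(120_000, 2_000, 500_000):.6f}")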
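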
 
1
+ # test1: MJ17 direct
2
+ # test2: "A1YU101" thailand cross-ref
3
+ # test3: "EBK109" thailand cross-ref
4
+ # test4: "OQ731952"/"BST115" for search query title: "South Asian maternal and paternal lineages in southern Thailand and"
5
+ import data_preprocess
6
+ import model
7
+ import mtdna_classifier
8
+ #import app
9
+ import smart_fallback
10
+ import pandas as pd
11
+ from pathlib import Path
12
+ import subprocess
13
+ from NER.html import extractHTML
14
+ import os
15
+ import google.generativeai as genai
16
+ import re
17
+ import standardize_location
18
+ # Helper functions for this pipeline
19
+ # Track time
20
+ import time
21
+ import multiprocessing
22
+ import gspread
23
+ from googleapiclient.discovery import build
24
+ from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload
25
+ from google.oauth2.service_account import Credentials
26
+ from oauth2client.service_account import ServiceAccountCredentials
27
+ import io
28
+ import json
29
+ #––– Authentication setup –––
30
+ GDRIVE_PARENT_FOLDER_NAME = "mtDNA-Location-Classifier"
31
+ GDRIVE_DATA_FOLDER_NAME = os.environ["GDRIVE_DATA_FOLDER_NAME"]
32
+ GCP_CREDS_DICT = json.loads(os.environ["GCP_CREDS_JSON"]) # from HF secrets
33
+ GDRIVE_CREDS = Credentials.from_service_account_info(GCP_CREDS_DICT, scopes=["https://www.googleapis.com/auth/drive"])
34
+ drive_service = build("drive", "v3", credentials=GDRIVE_CREDS)
35
+
36
+ def get_or_create_drive_folder(name, parent_id=None):
37
+ query = f"name='{name}' and mimeType='application/vnd.google-apps.folder'"
38
+ if parent_id:
39
+ query += f" and '{parent_id}' in parents"
40
+ results = drive_service.files().list(q=query, spaces='drive', fields="files(id, name)").execute()
41
+ items = results.get("files", [])
42
+ if items:
43
+ return items[0]["id"]
44
+ file_metadata = {
45
+ "name": name,
46
+ "mimeType": "application/vnd.google-apps.folder"
47
+ }
48
+ if parent_id:
49
+ file_metadata["parents"] = [parent_id]
50
+ file = drive_service.files().create(body=file_metadata, fields="id").execute()
51
+ return file["id"]
52
+ def find_drive_file(filename, parent_id):
53
+ """
54
+ Checks if a file with the given name exists inside the specified Google Drive folder.
55
+ Returns the file ID if found, else None.
56
+ """
57
+ query = f"'{parent_id}' in parents and name = '{filename}' and trashed = false"
58
+ results = drive_service.files().list(q=query, spaces='drive', fields='files(id, name)', pageSize=1).execute()
59
+ files = results.get('files', [])
60
+ if files:
61
+ return files[0]["id"]
62
+ return None
63
+
64
+
65
+ # def upload_file_to_drive(local_path, remote_name, folder_id):
66
+ # file_metadata = {"name": remote_name, "parents": [folder_id]}
67
+ # media = MediaFileUpload(local_path, resumable=True)
68
+ # existing = drive_service.files().list(q=f"name='{remote_name}' and '{folder_id}' in parents", fields="files(id)").execute().get("files", [])
69
+ # if existing:
70
+ # drive_service.files().delete(fileId=existing[0]["id"]).execute()
71
+ # file = drive_service.files().create(body=file_metadata, media_body=media, fields="id").execute()
72
+ # result = drive_service.files().list(q=f"name='{remote_name}' and '{folder_id}' in parents", fields="files(id)").execute()
73
+ # if not result.get("files"):
74
+ # print(f"❌ Upload failed: File '{remote_name}' not found in folder after upload.")
75
+ # else:
76
+ # print(f"✅ Verified upload: {remote_name}")
77
+ # return file["id"]
78
+ def upload_file_to_drive(local_path, remote_name, folder_id):
79
+ try:
80
+ if not os.path.exists(local_path):
81
+ raise FileNotFoundError(f"❌ Local file does not exist: {local_path}")
82
+
83
+ # Delete existing file on Drive if present
84
+ existing = drive_service.files().list(
85
+ q=f"name='{remote_name}' and '{folder_id}' in parents and trashed = false",
86
+ fields="files(id)"
87
+ ).execute().get("files", [])
88
+
89
+ if existing:
90
+ drive_service.files().delete(fileId=existing[0]["id"]).execute()
91
+ print(f"🗑️ Deleted existing '{remote_name}' in Drive folder {folder_id}")
92
+
93
+ file_metadata = {"name": remote_name, "parents": [folder_id]}
94
+ media = MediaFileUpload(local_path, resumable=True)
95
+ file = drive_service.files().create(
96
+ body=file_metadata,
97
+ media_body=media,
98
+ fields="id"
99
+ ).execute()
100
+
101
+ print(f"✅ Uploaded '{remote_name}' to Google Drive folder ID: {folder_id}")
102
+ return file["id"]
103
+
104
+ except Exception as e:
105
+ print(f"❌ Error during upload: {e}")
106
+ return None
107
+
108
+
109
+ def download_file_from_drive(remote_name, folder_id, local_path):
110
+ results = drive_service.files().list(q=f"name='{remote_name}' and '{folder_id}' in parents", fields="files(id)").execute()
111
+ files = results.get("files", [])
112
+ if not files:
113
+ return False
114
+ file_id = files[0]["id"]
115
+ request = drive_service.files().get_media(fileId=file_id)
116
+ fh = io.FileIO(local_path, 'wb')
117
+ downloader = MediaIoBaseDownload(fh, request)
118
+ done = False
119
+ while not done:
120
+ _, done = downloader.next_chunk()
121
+ return True
122
+ def download_drive_file_content(file_id):
123
+ request = drive_service.files().get_media(fileId=file_id)
124
+ fh = io.BytesIO()
125
+ downloader = MediaIoBaseDownload(fh, request)
126
+ done = False
127
+ while not done:
128
+ _, done = downloader.next_chunk()
129
+ fh.seek(0)
130
+ return fh.read().decode("utf-8")
131
+
132
+ # def run_with_timeout(func, args=(), kwargs={}, timeout=20):
133
+ # """
134
+ # Runs `func` with timeout in seconds. Kills if it exceeds.
135
+ # Returns: (success, result or None)
136
+ # """
137
+ # def wrapper(q, *args, **kwargs):
138
+ # try:
139
+ # q.put(func(*args, **kwargs))
140
+ # except Exception as e:
141
+ # q.put(e)
142
+
143
+ # q = multiprocessing.Queue()
144
+ # p = multiprocessing.Process(target=wrapper, args=(q, *args), kwargs=kwargs)
145
+ # p.start()
146
+ # p.join(timeout)
147
+
148
+ # if p.is_alive():
149
+ # p.terminate()
150
+ # p.join()
151
+ # print(f"⏱️ Timeout exceeded ({timeout} sec) — function killed.")
152
+ # return False, None
153
+ # else:
154
+ # result = q.get()
155
+ # if isinstance(result, Exception):
156
+ # raise result
157
+ # return True, result
158
+ def run_with_timeout(func, args=(), kwargs={}, timeout=30):
159
+ import concurrent.futures
160
+ with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
161
+ future = executor.submit(func, *args, **kwargs)
162
+ try:
163
+ return True, future.result(timeout=timeout)
164
+ except concurrent.futures.TimeoutError:
165
+ print(f"⏱️ Timeout exceeded ({timeout} sec) — function killed.")
166
+ return False, None
167
+
168
+ def time_it(func, *args, **kwargs):
169
+ """
170
+ Measure how long a function takes to run and return its result + time.
171
+ """
172
+ start = time.time()
173
+ result = func(*args, **kwargs)
174
+ end = time.time()
175
+ elapsed = end - start
176
+ print(f"⏱️ '{func.__name__}' took {elapsed:.3f} seconds")
177
+ return result, elapsed
178
+ # --- Define Pricing Constants (for Gemini 1.5 Flash & text-embedding-004) ---
179
+ def track_gemini_cost():
180
+ # Prices are per 1,000 tokens
181
+ PRICE_PER_1K_INPUT_LLM = 0.000075 # $0.075 per 1M tokens
182
+ PRICE_PER_1K_OUTPUT_LLM = 0.0003 # $0.30 per 1M tokens
183
+ PRICE_PER_1K_EMBEDDING_INPUT = 0.000025 # $0.025 per 1M tokens
184
+ return True
185
+
186
+ def unique_preserve_order(seq):
187
+ seen = set()
188
+ return [x for x in seq if not (x in seen or seen.add(x))]
189
+ # Main execution
190
+ def pipeline_with_gemini(accessions):
191
+ # output: country, sample_type, ethnic, location, money_cost, time_cost, explain
192
+ # there can be one accession number in the accessions
193
+ # Prices are per 1,000 tokens
194
+ PRICE_PER_1K_INPUT_LLM = 0.000075 # $0.075 per 1M tokens
195
+ PRICE_PER_1K_OUTPUT_LLM = 0.0003 # $0.30 per 1M tokens
196
+ PRICE_PER_1K_EMBEDDING_INPUT = 0.000025 # $0.025 per 1M tokens
197
+ if not accessions:
198
+ print("no input")
199
+ return None
200
+ else:
201
+ accs_output = {}
202
+ genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
203
+ for acc in accessions:
204
+ print("start gemini: ", acc)
205
+ start = time.time()
206
+ total_cost_title = 0
207
+ jsonSM, links, article_text = {},[], ""
208
+ acc_score = { "isolate": "",
209
+ "country":{},
210
+ "sample_type":{},
211
+ #"specific_location":{},
212
+ #"ethnicity":{},
213
+ "query_cost":total_cost_title,
214
+ "time_cost":None,
215
+ "source":links}
216
+ meta = mtdna_classifier.fetch_ncbi_metadata(acc)
217
+ country, spe_loc, ethnic, sample_type, col_date, iso, title, doi, pudID, features = meta["country"], meta["specific_location"], meta["ethnicity"], meta["sample_type"], meta["collection_date"], meta["isolate"], meta["title"], meta["doi"], meta["pubmed_id"], meta["all_features"]
218
+ acc_score["isolate"] = iso
219
+ print("meta: ",meta)
220
+ meta_expand = smart_fallback.fetch_ncbi(acc)
221
+ print("meta expand: ", meta_expand)
222
+ # set up step: create the folder to save document
223
+ chunk, all_output = "",""
224
+ if pudID:
225
+ id = str(pudID)
226
+ saveTitle = title
227
+ else:
228
+ try:
229
+ author_name = meta_expand["authors"].split(',')[0] # Use last name only
230
+ except:
231
+ author_name = meta_expand["authors"]
232
+ saveTitle = title + "_" + col_date + "_" + author_name
233
+ id = "DirectSubmission"
234
+ # folder_path = Path("/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/data/"+str(id))
235
+ # if not folder_path.exists():
236
+ # cmd = f'mkdir /content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/data/{id}'
237
+ # result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
238
+ # print("data/"+str(id) +" created.")
239
+ # else:
240
+ # print("data/"+str(id) +" already exists.")
241
+ # saveLinkFolder = "/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/data/"+str(id)
242
+ # parent_folder_id = get_or_create_drive_folder(GDRIVE_PARENT_FOLDER_NAME)
243
+ # data_folder_id = get_or_create_drive_folder(GDRIVE_DATA_FOLDER_NAME, parent_id=parent_folder_id)
244
+ # sample_folder_id = get_or_create_drive_folder(str(id), parent_id=data_folder_id)
245
+ data_folder_id = GDRIVE_DATA_FOLDER_NAME # Use the shared folder directly
246
+ sample_folder_id = get_or_create_drive_folder(str(id), parent_id=data_folder_id)
247
+ print("sample folder id: ", sample_folder_id)
248
+
249
+ # Define document names
250
+ if len(saveTitle) > 50:
251
+ saveName = saveTitle[:50]
252
+ saveName = saveName.replace(" ", "_")
253
+ chunk_filename = f"{saveName}_merged_document.docx"
254
+ all_filename = f"{saveName}_all_merged_document.docx"
255
+ else:
256
+ saveName = saveTitle.replace(" ", "_")
257
+ chunk_filename = f"{saveName}_merged_document.docx"
258
+ all_filename = f"{saveName}_all_merged_document.docx"
259
+ print(chunk_filename, all_filename)
260
+ # Define local temp paths for reading/writing
261
+ # import tempfile
262
+ # tmp_dir = tempfile.mkdtemp()
263
+ LOCAL_TEMP_DIR = "/mnt/data/generated_docs"
264
+ os.makedirs(LOCAL_TEMP_DIR, exist_ok=True)
265
+ file_chunk_path = os.path.join(LOCAL_TEMP_DIR, chunk_filename)
266
+ file_all_path = os.path.join(LOCAL_TEMP_DIR, all_filename)
267
+ # file_chunk_path = os.path.join(tempfile.gettempdir(), chunk_filename)
268
+ # file_all_path = os.path.join(tempfile.gettempdir(), all_filename)
269
+ print(file_chunk_path)
270
+ chunk_id = find_drive_file(chunk_filename, sample_folder_id)
271
+ all_id = find_drive_file(all_filename, sample_folder_id)
272
+
273
+ if chunk_id and all_id:
274
+ print("✅ Files already exist in Google Drive. Downloading them...")
275
+ chunk_exists = download_file_from_drive(chunk_filename, sample_folder_id, file_chunk_path)
276
+ all_exists = download_file_from_drive(all_filename, sample_folder_id, file_all_path)
277
+ print("chunk_id and all_id: ")
278
+ print(chunk_id, all_id)
279
+ file = drive_service.files().get(fileId=chunk_id, fields="id, name, parents, webViewLink").execute()  # inspect the cached chunk file that was just located
280
+ print("📄 Name:", file["name"])
281
+ print("📁 Parent folder ID:", file["parents"][0])
282
+ print("🔗 View link:", file["webViewLink"])
283
+
284
+
285
+ # Read and parse these into `chunk` and `all_output`
286
+ else:
287
+ # 🔥 Remove any stale local copies
288
+ if os.path.exists(file_chunk_path):
289
+ os.remove(file_chunk_path)
290
+ print(f"🗑️ Removed stale: {file_chunk_path}")
291
+ if os.path.exists(file_all_path):
292
+ os.remove(file_all_path)
293
+ print(f"🗑️ Removed stale: {file_all_path}")
294
+ # 🔥 Remove the local file first if it exists
295
+ # if os.path.exists(file_chunk_path):
296
+ # os.remove(file_chunk_path)
297
+ # print("remove chunk path")
298
+ # if os.path.exists(file_all_path):
299
+ # os.remove(file_all_path)
300
+ # print("remove all path")
301
+ # Try to download if already exists on Drive
302
+ chunk_exists = download_file_from_drive(chunk_filename, sample_folder_id, file_chunk_path)
303
+ all_exists = download_file_from_drive(all_filename, sample_folder_id, file_all_path)
304
+ print("chunk exist: ", chunk_exists)
305
+ # first way: ncbi method
306
+ print("country.lower: ",country.lower())
307
+ if country.lower() != "unknown":
308
+ stand_country = standardize_location.smart_country_lookup(country.lower())
309
+ print("stand_country: ", stand_country)
310
+ if stand_country.lower() != "not found":
311
+ acc_score["country"][stand_country.lower()] = ["ncbi"]
312
+ else: acc_score["country"][country.lower()] = ["ncbi"]
313
+ # if spe_loc.lower() != "unknown":
314
+ # acc_score["specific_location"][spe_loc.lower()] = ["ncbi"]
315
+ # if ethnic.lower() != "unknown":
316
+ # acc_score["ethnicity"][ethnic.lower()] = ["ncbi"]
317
+ if sample_type.lower() != "unknown":
318
+ acc_score["sample_type"][sample_type.lower()] = ["ncbi"]
319
+ # second way: LLM model
320
+ # Preprocess the input token
321
+ print(acc_score)
322
+ accession, isolate = None, None
323
+ if acc != "unknown": accession = acc
324
+ if iso != "unknown": isolate = iso
325
+ # check doi first
326
+ if doi != "unknown":
327
+ link = 'https://doi.org/' + doi
328
+ # get the file to create listOfFile for each id
329
+ print("link of doi: ", link)
330
+ html = extractHTML.HTML("",link)
331
+ jsonSM = html.getSupMaterial()
332
+ article_text = html.getListSection()
333
+ if article_text:
334
+ if "Just a moment...Enable JavaScript and cookies to continue".lower() not in article_text.lower() or "403 Forbidden Request".lower() not in article_text.lower():
335
+ links.append(link)
336
+ if jsonSM:
337
+ links += sum((jsonSM[key] for key in jsonSM),[])
338
+ # no doi then google custom search api
339
+ if doi=="unknown" or len(article_text) == 0 or "Just a moment...Enable JavaScript and cookies to continue".lower() in article_text.lower() or "403 Forbidden Request".lower() in article_text.lower():
340
+ # might find the article
341
+ print("no article text, start tem link")
342
+ #tem_links = mtdna_classifier.search_google_custom(title, 2)
343
+ tem_links = smart_fallback.smart_google_search(meta_expand)
344
+ print("tem links: ", tem_links)
345
+ tem_link_acc = smart_fallback.google_accession_search(acc)
346
+ tem_links += tem_link_acc
347
+ tem_links = unique_preserve_order(tem_links)
348
+ print("tem link before filtering: ", tem_links)
349
+ # filter the quality link
350
+ print("saveLinkFolder as sample folder id: ", sample_folder_id)
351
+ links = smart_fallback.filter_links_by_metadata(tem_links, saveLinkFolder=sample_folder_id, accession=acc)
352
+ print("this is links: ",links)
353
+ links = unique_preserve_order(links)
354
+ acc_score["source"] = links
355
+ # chunk_path = "/"+saveTitle+"_merged_document.docx"
356
+ # all_path = "/"+saveTitle+"_all_merged_document.docx"
357
+ # # if chunk and all output not exist yet
358
+ # file_chunk_path = saveLinkFolder + chunk_path
359
+ # file_all_path = saveLinkFolder + all_path
360
+ # if os.path.exists(file_chunk_path):
361
+ # print("File chunk exists!")
362
+ # if not chunk:
363
+ # text, table, document_title = model.read_docx_text(file_chunk_path)
364
+ # chunk = data_preprocess.normalize_for_overlap(text) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table))
365
+ # if os.path.exists(file_all_path):
366
+ # print("File all output exists!")
367
+ # if not all_output:
368
+ # text_all, table_all, document_title_all = model.read_docx_text(file_all_path)
369
+ # all_output = data_preprocess.normalize_for_overlap(text_all) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table_all))
370
+ if chunk_exists:
371
+ print("File chunk exists!")
372
+ if not chunk:
373
+ print("start to get chunk")
374
+ text, table, document_title = model.read_docx_text(file_chunk_path)
375
+ chunk = data_preprocess.normalize_for_overlap(text) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table))
376
+ if all_exists:
377
+ print("File all output exists!")
378
+ if not all_output:
379
+ text_all, table_all, document_title_all = model.read_docx_text(file_all_path)
380
+ all_output = data_preprocess.normalize_for_overlap(text_all) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table_all))
381
+ if not chunk and not all_output:
382
+ print("not chunk and all output")
383
+ # else: check if we can reuse these chunk and all output of existed accession to find another
384
+ if links:
385
+ for link in links:
386
+ print(link)
387
+ # if len(all_output) > 1000*1000:
388
+ # all_output = data_preprocess.normalize_for_overlap(all_output)
389
+ # print("after normalizing all output: ", len(all_output))
390
+ if len(data_preprocess.normalize_for_overlap(all_output)) > 600000:
391
+ print("break here")
392
+ break
393
+ if iso != "unknown": query_kw = iso
394
+ else: query_kw = acc
395
+ #text_link, tables_link, final_input_link = data_preprocess.preprocess_document(link,saveLinkFolder, isolate=query_kw)
396
+ success_process, output_process = run_with_timeout(data_preprocess.preprocess_document,args=(link,sample_folder_id),kwargs={"isolate":query_kw,"accession":acc},timeout=180)
397
+ if success_process:
398
+ text_link, tables_link, final_input_link = output_process[0], output_process[1], output_process[2]
399
+ print("yes succeed for process document")
400
+ else: text_link, tables_link, final_input_link = "", "", ""
401
+ context = data_preprocess.extract_context(final_input_link, query_kw)
402
+ if context != "Sample ID not found.":
403
+ if len(data_preprocess.normalize_for_overlap(chunk)) < 1000*1000:
404
+ success_chunk, the_output_chunk = run_with_timeout(data_preprocess.merge_texts_skipping_overlap,args=(chunk, context))
405
+ if success_chunk:
406
+ chunk = the_output_chunk#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
407
+ print("yes succeed for chunk")
408
+ else:
409
+ chunk += context
410
+ print("len context: ", len(context))
411
+ print("basic fall back")
412
+ print("len chunk after: ", len(chunk))
413
+ if len(final_input_link) > 1000*1000:
414
+ if context != "Sample ID not found.":
415
+ final_input_link = context
416
+ else:
417
+ final_input_link = data_preprocess.normalize_for_overlap(final_input_link)
418
+ if len(final_input_link) > 1000 *1000:
419
+ final_input_link = final_input_link[:100000]
420
+ if len(data_preprocess.normalize_for_overlap(all_output)) < 1000*1000:
421
+ success, the_output = run_with_timeout(data_preprocess.merge_texts_skipping_overlap,args=(all_output, final_input_link))
422
+ if success:
423
+ all_output = the_output#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
424
+ print("yes succeed")
425
+ else:
426
+ all_output += final_input_link
427
+ print("len final input: ", len(final_input_link))
428
+ print("basic fall back")
429
+ print("len all output after: ", len(all_output))
430
+ #country_pro, chunk, all_output = data_preprocess.process_inputToken(links, saveLinkFolder, accession=accession, isolate=isolate)
431
+
432
+ else:
433
+ chunk = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
434
+ all_output = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
435
+ if not chunk: chunk = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
436
+ if not all_output: all_output = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
437
+ if len(all_output) > 1*1024*1024:
438
+ all_output = data_preprocess.normalize_for_overlap(all_output)
439
+ if len(all_output) > 1*1024*1024:
440
+ all_output = all_output[:1*1024*1024]
441
+ print("chunk len: ", len(chunk))
442
+ print("all output len: ", len(all_output))
443
+ data_preprocess.save_text_to_docx(chunk, file_chunk_path)
444
+ data_preprocess.save_text_to_docx(all_output, file_all_path)
445
+ # Later when saving new files
446
+ # data_preprocess.save_text_to_docx(chunk, chunk_filename, sample_folder_id)
447
+ # data_preprocess.save_text_to_docx(all_output, all_filename, sample_folder_id)
448
+
449
+ # Upload to Drive
450
+ result_chunk_upload = upload_file_to_drive(file_chunk_path, chunk_filename, sample_folder_id)
451
+ result_all_upload = upload_file_to_drive(file_all_path, all_filename, sample_folder_id)
452
+ print("UPLOAD RESULT FOR CHUNK: ", result_chunk_upload)
453
+ print(f"🔗 Uploaded file: https://drive.google.com/file/d/{result_chunk_upload}/view")
454
+ print("here 1")
455
+
456
+ # else:
457
+ # final_input = ""
458
+ # if all_output:
459
+ # final_input = all_output
460
+ # else:
461
+ # if chunk: final_input = chunk
462
+ # #data_preprocess.merge_texts_skipping_overlap(final_input, all_output)
463
+ # if final_input:
464
+ # keywords = []
465
+ # if iso != "unknown": keywords.append(iso)
466
+ # if acc != "unknown": keywords.append(acc)
467
+ # for keyword in keywords:
468
+ # chunkBFS = data_preprocess.get_contextual_sentences_BFS(final_input, keyword)
469
+ # countryDFS, chunkDFS = data_preprocess.get_contextual_sentences_DFS(final_input, keyword)
470
+ # chunk = data_preprocess.merge_texts_skipping_overlap(chunk, chunkDFS)
471
+ # chunk = data_preprocess.merge_texts_skipping_overlap(chunk, chunkBFS)
472
+
473
+ # Define paths for cached RAG assets
474
+ # faiss_index_path = saveLinkFolder+"/faiss_index.bin"
475
+ # document_chunks_path = saveLinkFolder+"/document_chunks.json"
476
+ # structured_lookup_path = saveLinkFolder+"/structured_lookup.json"
477
+ print("here 2")
478
+ faiss_filename = "faiss_index.bin"
479
+ chunks_filename = "document_chunks.json"
480
+ lookup_filename = "structured_lookup.json"
481
+ print("name of faiss: ", faiss_filename)
482
+
483
+ faiss_index_path = os.path.join(LOCAL_TEMP_DIR, faiss_filename)
484
+ document_chunks_path = os.path.join(LOCAL_TEMP_DIR, chunks_filename)
485
+ structured_lookup_path = os.path.join(LOCAL_TEMP_DIR, lookup_filename)
486
+ print("name if faiss path: ", faiss_index_path)
487
+ # 🔥 Remove the local file first if it exists
488
+ faiss_id = find_drive_file(faiss_filename, sample_folder_id)
489
+ document_id = find_drive_file(chunks_filename, sample_folder_id)
490
+ structure_id = find_drive_file(lookup_filename, sample_folder_id)
491
+ if faiss_id and document_id and structure_id:
492
+ print("✅ 3 Files already exist in Google Drive. Downloading them...")
493
+ download_file_from_drive(faiss_filename, sample_folder_id, faiss_index_path)
494
+ download_file_from_drive(chunks_filename, sample_folder_id, document_chunks_path)
495
+ download_file_from_drive(lookup_filename, sample_folder_id, structured_lookup_path)
496
+ # Read and parse these into `chunk` and `all_output`
497
+ else:
498
+ if os.path.exists(faiss_index_path):
499
+ os.remove(faiss_index_path)
500
+ if os.path.exists(document_chunks_path):
501
+ os.remove(document_chunks_path)
502
+ if os.path.exists(structured_lookup_path):
503
+ os.remove(structured_lookup_path)
504
+ download_file_from_drive(faiss_filename, sample_folder_id, faiss_index_path)
505
+ download_file_from_drive(chunks_filename, sample_folder_id, document_chunks_path)
506
+ download_file_from_drive(lookup_filename, sample_folder_id, structured_lookup_path)
507
+
508
+ print("move to load rag")
509
+ master_structured_lookup, faiss_index, document_chunks = model.load_rag_assets(
510
+ faiss_index_path, document_chunks_path, structured_lookup_path
511
+ )
512
+
513
+ global_llm_model_for_counting_tokens = genai.GenerativeModel('gemini-1.5-flash-latest')
514
+ if not all_output:
515
+ if chunk: all_output = chunk
516
+ else: all_output = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
517
+ if faiss_index is None:
518
+ print("\nBuilding RAG assets (structured lookup, FAISS index, chunks)...")
519
+ total_doc_embedding_tokens = global_llm_model_for_counting_tokens.count_tokens(
520
+ all_output
521
+ ).total_tokens
522
+
523
+ initial_embedding_cost = (total_doc_embedding_tokens / 1000) * PRICE_PER_1K_EMBEDDING_INPUT
524
+ total_cost_title += initial_embedding_cost
525
+ print(f"Initial one-time embedding cost for '{file_all_path}' ({total_doc_embedding_tokens} tokens): ${initial_embedding_cost:.6f}")
526
+
527
+
528
+ master_structured_lookup, faiss_index, document_chunks, plain_text_content = model.build_vector_index_and_data(
529
+ file_all_path, faiss_index_path, document_chunks_path, structured_lookup_path
530
+ )
531
+ else:
532
+ print("\nRAG assets loaded from file. No re-embedding of entire document will occur.")
533
+ plain_text_content_all, table_strings_all, document_title_all = model.read_docx_text(file_all_path)
534
+ master_structured_lookup['document_title'] = master_structured_lookup.get('document_title', document_title_all)
535
+
536
+ primary_word = iso
537
+ alternative_word = acc
538
+ print(f"\n--- General Query: Primary='{primary_word}' (Alternative='{alternative_word}') ---")
539
+ if features.lower() not in all_output.lower():
540
+ all_output += ". NCBI Features: " + features
541
+ # country, sample_type, method_used, ethnic, spe_loc, total_query_cost = model.query_document_info(
542
+ # primary_word, alternative_word, meta, master_structured_lookup, faiss_index, document_chunks,
543
+ # model.call_llm_api, chunk=chunk, all_output=all_output)
544
+ print("this is chunk for the model")
545
+ print(chunk)
546
+ print("this is all output for the model")
547
+ print(all_output)
548
+ country, sample_type, method_used, country_explanation, sample_type_explanation, total_query_cost = model.query_document_info(
549
+ primary_word, alternative_word, meta, master_structured_lookup, faiss_index, document_chunks,
550
+ model.call_llm_api, chunk=chunk, all_output=all_output)
551
+ print("country using ai: ", country)
552
+ print("sample type using ai: ", sample_type)
553
+ if len(country) == 0: country = "unknown"
554
+ if len(sample_type) == 0: sample_type = "unknown"
555
+ if country_explanation: country_explanation = "-"+country_explanation
556
+ else: country_explanation = ""
557
+ if sample_type_explanation: sample_type_explanation = "-"+sample_type_explanation
558
+ else: sample_type_explanation = ""
559
+ if method_used == "unknown": method_used = ""
560
+ if country.lower() != "unknown":
561
+ stand_country = standardize_location.smart_country_lookup(country.lower())
562
+ if stand_country.lower() != "not found":
563
+ if stand_country.lower() in acc_score["country"]:
564
+ if country_explanation:
565
+ acc_score["country"][stand_country.lower()].append(method_used + country_explanation)
566
+ else:
567
+ acc_score["country"][stand_country.lower()] = [method_used + country_explanation]
568
+ else:
569
+ if country.lower() in acc_score["country"]:
570
+ if country_explanation:
571
+ if len(method_used + country_explanation) > 0:
572
+ acc_score["country"][country.lower()].append(method_used + country_explanation)
573
+ else:
574
+ if len(method_used + country_explanation) > 0:
575
+ acc_score["country"][country.lower()] = [method_used + country_explanation]
576
+ # if spe_loc.lower() != "unknown":
577
+ # if spe_loc.lower() in acc_score["specific_location"]:
578
+ # acc_score["specific_location"][spe_loc.lower()].append(method_used)
579
+ # else:
580
+ # acc_score["specific_location"][spe_loc.lower()] = [method_used]
581
+ # if ethnic.lower() != "unknown":
582
+ # if ethnic.lower() in acc_score["ethnicity"]:
583
+ # acc_score["ethnicity"][ethnic.lower()].append(method_used)
584
+ # else:
585
+ # acc_score["ethnicity"][ethnic.lower()] = [method_used]
586
+ if sample_type.lower() != "unknown":
587
+ if sample_type.lower() in acc_score["sample_type"]:
588
+ if len(method_used + sample_type_explanation) > 0:
589
+ acc_score["sample_type"][sample_type.lower()].append(method_used + sample_type_explanation)
590
+ else:
591
+ if len(method_used + sample_type_explanation)> 0:
592
+ acc_score["sample_type"][sample_type.lower()] = [method_used + sample_type_explanation]
593
+ # last resort: combine all information to give all output otherwise unknown
594
+ if len(acc_score["country"]) == 0 or len(acc_score["sample_type"]) == 0:
595
+ text = ""
596
+ for key in meta_expand:
597
+ text += str(key) + ": " + meta_expand[key] + "\n"
598
+ if len(data_preprocess.normalize_for_overlap(all_output)) > 0:
599
+ text += data_preprocess.normalize_for_overlap(all_output)
600
+ if len(data_preprocess.normalize_for_overlap(chunk)) > 0:
601
+ text += data_preprocess.normalize_for_overlap(chunk)
602
+ text += ". NCBI Features: " + features
603
+ print("this is text for the last resort model")
604
+ print(text)
605
+ country, sample_type, method_used, country_explanation, sample_type_explanation, total_query_cost = model.query_document_info(
606
+ primary_word, alternative_word, meta, master_structured_lookup, faiss_index, document_chunks,
607
+ model.call_llm_api, chunk=text, all_output=text)
608
+ print("this is last resort results: ")
609
+ print("country: ", country)
610
+ print("sample type: ", sample_type)
611
+ if len(country) == 0: country = "unknown"
612
+ if len(sample_type) == 0: sample_type = "unknown"
613
+ if country_explanation: country_explanation = "-"+country_explanation
614
+ else: country_explanation = ""
615
+ if sample_type_explanation: sample_type_explanation = "-"+sample_type_explanation
616
+ else: sample_type_explanation = ""
617
+ if method_used == "unknown": method_used = ""
618
+ if country.lower() != "unknown":
619
+ stand_country = standardize_location.smart_country_lookup(country.lower())
620
+ if stand_country.lower() != "not found":
621
+ if stand_country.lower() in acc_score["country"]:
622
+ if country_explanation:
623
+ acc_score["country"][stand_country.lower()].append(method_used + country_explanation)
624
+ else:
625
+ acc_score["country"][stand_country.lower()] = [method_used + country_explanation]
626
+ else:
627
+ if country.lower() in acc_score["country"]:
628
+ if country_explanation:
629
+ if len(method_used + country_explanation) > 0:
630
+ acc_score["country"][country.lower()].append(method_used + country_explanation)
631
+ else:
632
+ if len(method_used + country_explanation) > 0:
633
+ acc_score["country"][country.lower()] = [method_used + country_explanation]
634
+ if sample_type.lower() != "unknown":
635
+ if sample_type.lower() in acc_score["sample_type"]:
636
+ if len(method_used + sample_type_explanation) > 0:
637
+ acc_score["sample_type"][sample_type.lower()].append(method_used + sample_type_explanation)
638
+ else:
639
+ if len(method_used + sample_type_explanation)> 0:
640
+ acc_score["sample_type"][sample_type.lower()] = [method_used + sample_type_explanation]
641
+ end = time.time()
642
+ total_cost_title += total_query_cost
643
+ acc_score["query_cost"] = f"{total_cost_title:.6f}"
644
+ elapsed = end - start
645
+ acc_score["time_cost"] = f"{elapsed:.3f} seconds"
646
+ accs_output[acc] = acc_score
647
+ print(accs_output[acc])
648
+
649
  return accs_output
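As a usage note, here is a minimal sketch of the Drive caching round-trip the pipeline performs with the helpers defined above (find_drive_file, download_file_from_drive, upload_file_to_drive). The file name, folder id, and the local write step are placeholders standing in for the real document-merging logic; running it requires the same Drive credentials set up earlier in this file.

import os

LOCAL_TEMP_DIR = "/mnt/data/generated_docs"   # same scratch dir the pipeline uses
filename = "example_cache.txt"                # placeholder file name
folder_id = "DRIVE_FOLDER_ID"                 # placeholder Drive folder id
local_path = os.path.join(LOCAL_TEMP_DIR, filename)
os.makedirs(LOCAL_TEMP_DIR, exist_ok=True)

# 1) Reuse the cached copy on Drive if it exists.
if find_drive_file(filename, folder_id):
    download_file_from_drive(filename, folder_id, local_path)
else:
    # 2) Otherwise build the artifact locally (stand-in for the real merge step)...
    with open(local_path, "w") as fh:
        fh.write("placeholder content")
    # 3) ...and push it back so the next run can skip the rebuild.
    upload_file_to_drive(local_path, filename, folder_id)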
requirements.txt CHANGED
@@ -1,44 +1,44 @@
1
- biopython==1.85
2
- bs4==0.0.2
3
- gensim==4.3.3
4
- gradio
5
- gspread==6.2.0
6
- gspread-dataframe==4.0.0
7
- huggingface-hub==0.30.2
8
- nltk==3.9.1
9
- oauth2client==4.1.3
10
- openai==1.76.2
11
- openpyxl==3.1.5
12
- pandas==2.2.2
13
- pdfreader==0.1.15
14
- PyMuPDF==1.25.5
15
- pytest==8.3.5
16
- requests==2.32.3
17
- scikit-learn==1.6.1
18
- scipy==1.13.1
19
- spacy==3.8.5
20
- spacy-lookups-data==1.0.5
21
- spire-doc==13.4.6
22
- Spire.Xls==14.12.0
23
- statsmodels==0.14.4
24
- tabula-py==2.10.0
25
- thefuzz==0.22.1
26
- torch
27
- transformers==4.51.3
28
- wordsegment==1.3.1
29
- xlrd==2.0.1
30
- sentence-transformers
31
- lxml
32
- streamlit
33
- requests
34
- google-generativeai
35
- PyPDF2
36
- beautifulsoup4
37
- # For Claude
38
- anthropic
39
- faiss-cpu
40
- python-docx
41
- pycountry
42
- # For Deepseek (If direct DeepseekLLM client library is available, use it.
43
- # Otherwise, 'requests' covers it for simple API calls, but a dedicated client is better for full features)
44
  # deepseek-llm # Uncomment this if Deepseek provides a dedicated pip package for their LLM
 
1
+ biopython==1.85
2
+ bs4==0.0.2
3
+ gensim==4.3.3
4
+ gradio
5
+ gspread==6.2.0
6
+ gspread-dataframe==4.0.0
7
+ huggingface-hub==0.30.2
8
+ nltk==3.9.1
9
+ oauth2client==4.1.3
10
+ openai==1.76.2
11
+ openpyxl==3.1.5
12
+ pandas==2.2.2
13
+ pdfreader==0.1.15
14
+ PyMuPDF==1.25.5
15
+ pytest==8.3.5
16
+ requests==2.32.3
17
+ scikit-learn==1.6.1
18
+ scipy==1.13.1
19
+ spacy==3.8.5
20
+ spacy-lookups-data==1.0.5
21
+ spire-doc==13.4.6
22
+ Spire.Xls==14.12.0
23
+ statsmodels==0.14.4
24
+ tabula-py==2.10.0
25
+ thefuzz==0.22.1
26
+ torch
27
+ transformers==4.51.3
28
+ wordsegment==1.3.1
29
+ xlrd==2.0.1
30
+ sentence-transformers
31
+ lxml
32
+ streamlit
33
+ requests
34
+ google-generativeai
35
+ PyPDF2
36
+ beautifulsoup4
37
+ # For Claude
38
+ anthropic
39
+ faiss-cpu
40
+ python-docx
41
+ pycountry
42
+ # For Deepseek (If direct DeepseekLLM client library is available, use it.
43
+ # Otherwise, 'requests' covers it for simple API calls, but a dedicated client is better for full features)
44
  # deepseek-llm # Uncomment this if Deepseek provides a dedicated pip package for their LLM
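For orientation before the smart_fallback.py changes below, here is a minimal sketch of the Bio.Entrez GenBank fetch that fetch_ncbi is built around. The e-mail is a placeholder (NCBI requires a real contact address), the accession in the usage comment is only an example, and error handling is omitted.

from Bio import Entrez

Entrez.email = "[email protected]"  # placeholder; replace with a real contact address

def fetch_genbank_record(accession):
    # Same call pattern as fetch_ncbi: fetch the GenBank record as parsed XML.
    handle = Entrez.efetch(db="nucleotide", id=str(accession), rettype="gb", retmode="xml")
    record = Entrez.read(handle)
    handle.close()
    return record[0] if record else None

# Usage (accession is an example value):
# gb_seq = fetch_genbank_record("OQ731952")
# print(gb_seq.get("GBSeq_definition", "unknown"))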
smart_fallback.py CHANGED
@@ -1,156 +1,205 @@
1
- from Bio import Entrez, Medline
2
- import model
3
- import mtdna_classifier
4
- # Setup
5
- def fetch_ncbi(accession_number):
6
- Entrez.email = "[email protected]" # Required by NCBI, REPLACE WITH YOUR EMAIL
7
- handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
8
- record = Entrez.read(handle)
9
- handle.close()
10
- outputs = {"authors":"unknown",
11
- "institution":"unknown",
12
- "isolate":"unknown",
13
- "definition":"unknown",
14
- "title":"unknown",
15
- "seq_comment":"unknown",
16
- "collection_date":"unknown" } #'GBSeq_update-date': '25-OCT-2023', 'GBSeq_create-date'
17
- gb_seq = None
18
- try:
19
- # Validate record structure: It should be a list with at least one element (a dict)
20
- if isinstance(record, list) and len(record) > 0:
21
- if isinstance(record[0], dict):
22
- gb_seq = record[0]
23
- else:
24
- print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
25
- # extract collection date
26
- if "GBSeq_create-date" in gb_seq and outputs["collection_date"]=="unknown":
27
- outputs["collection_date"] = gb_seq["GBSeq_create-date"]
28
- else:
29
- if "GBSeq_update-date" in gb_seq and outputs["collection_date"]=="unknown":
30
- outputs["collection_date"] = gb_seq["GBSeq_update-date"]
31
- # extract definition
32
- if "GBSeq_definition" in gb_seq and outputs["definition"]=="unknown":
33
- outputs["definition"] = gb_seq["GBSeq_definition"]
34
- # extract related-reference things
35
- if "GBSeq_references" in gb_seq:
36
- for ref in gb_seq["GBSeq_references"]:
37
- # extract authors
38
- if "GBReference_authors" in ref and outputs["authors"]=="unknown":
39
- outputs["authors"] = "and ".join(ref["GBReference_authors"])
40
- # extract title
41
- if "GBReference_title" in ref and outputs["title"]=="unknown":
42
- outputs["title"] = ref["GBReference_title"]
43
- # extract submitted journal
44
- if 'GBReference_journal' in ref and outputs["institution"]=="unknown":
45
- outputs["institution"] = ref['GBReference_journal']
46
- # extract seq_comment
47
- if 'GBSeq_comment'in gb_seq and outputs["seq_comment"]=="unknown":
48
- outputs["seq_comment"] = gb_seq["GBSeq_comment"]
49
- # extract isolate
50
- if "GBSeq_feature-table" in gb_seq:
51
- if 'GBFeature_quals' in gb_seq["GBSeq_feature-table"][0]:
52
- for ref in gb_seq["GBSeq_feature-table"][0]["GBFeature_quals"]:
53
- if ref['GBQualifier_name'] == "isolate" and outputs["isolate"]=="unknown":
54
- outputs["isolate"] = ref["GBQualifier_value"]
55
- else:
56
- print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
57
-
58
- # If gb_seq is still None, return defaults
59
- if gb_seq is None:
60
- return {"authors":"unknown",
61
- "institution":"unknown",
62
- "isolate":"unknown",
63
- "definition":"unknown",
64
- "title":"unknown",
65
- "seq_comment":"unknown",
66
- "collection_date":"unknown" }
67
- return outputs
68
- except:
69
- print("error in fetching ncbi data")
70
- return {"authors":"unknown",
71
- "institution":"unknown",
72
- "isolate":"unknown",
73
- "definition":"unknown",
74
- "title":"unknown",
75
- "seq_comment":"unknown",
76
- "collection_date":"unknown" }
77
- # Method 1: Smarter Google
78
- def smart_google_queries(metadata: dict):
79
- queries = []
80
-
81
- # Extract useful fields
82
- isolate = metadata.get("isolate")
83
- author = metadata.get("authors")
84
- institution = metadata.get("institution")
85
- title = metadata.get("title")
86
- print(title)
87
- combined = []
88
- # Construct queries
89
- if isolate:
90
- queries.append(f'"{isolate}" mitochondrial DNA')
91
- queries.append(f'"{isolate}" site:ncbi.nlm.nih.gov')
92
-
93
- if author:
94
- try:
95
- author_name = author.split(',')[0] # Use last name only
96
- except:
97
- author_name = author
98
- queries.append(f'"{author_name}" mitochondrial DNA')
99
- queries.append(f'"{author_name}" mtDNA site:researchgate.net')
100
-
101
- if institution:
102
- try:
103
- short_inst = institution.split(',')[0] # Take first part of institution
104
- except:
105
- short_inst = institution
106
- queries.append(f'"{short_inst}" mtDNA sequence')
107
- queries.append(f'"{short_inst}" isolate site:nature.com')
108
- queries.append(title)
109
- return queries
110
-
111
- def filter_links_by_metadata(search_results):
112
- TRUSTED_DOMAINS = [
113
- "ncbi.nlm.nih.gov",
114
- "pubmed.ncbi.nlm.nih.gov",
115
- "pmc.ncbi.nlm.nih.gov",
116
- "biorxiv.org",
117
- "researchgate.net",
118
- "nature.com",
119
- "sciencedirect.com"
120
- ]
121
- def is_trusted_link(link):
122
- for domain in TRUSTED_DOMAINS:
123
- if domain in link:
124
- return True
125
- return False
126
- def is_relevant_title_snippet(link):
127
- keywords = ["mtDNA", "mitochondrial", "accession", "isolate", "Homo sapiens", "sequence"]
128
- title_snippet = link.lower()
129
- for keyword in keywords:
130
- if keyword in title_snippet:
131
- return True
132
- return False
133
-
134
- filtered = []
135
- if len(search_results) > 0:
136
- for link in search_results:
137
- if is_trusted_link(link) and link not in filtered:
138
- filtered.append(link)
139
- if is_relevant_title_snippet(link) and link not in filtered:
140
- filtered.append(link)
141
- return filtered
142
-
143
- def smart_google_search(metadata):
144
- queries = smart_google_queries(metadata)
145
- links = []
146
- for q in queries:
147
- #print("\n🔍 Query:", q)
148
- results = mtdna_classifier.search_google_custom(q,2)
149
- for link in results:
150
- #print(f"- {link}")
151
- if link not in links:
152
- links.append(link)
153
- filter_links = filter_links_by_metadata(links)
154
- return filter_links
155
- # Method 2: Prompt LLM better or better ai search api with all
156
  # the total information from even ncbi and all search
 
1
+ from Bio import Entrez, Medline
2
+ #import model
3
+ import mtdna_classifier
4
+ from NER.html import extractHTML
5
+ import data_preprocess
6
+ # Setup
7
+ def fetch_ncbi(accession_number):
8
+ try:
9
+ Entrez.email = "[email protected]" # Required by NCBI, REPLACE WITH YOUR EMAIL
10
+ handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
11
+ record = Entrez.read(handle)
12
+ handle.close()
13
+ outputs = {"authors":"unknown",
14
+ "institution":"unknown",
15
+ "isolate":"unknown",
16
+ "definition":"unknown",
17
+ "title":"unknown",
18
+ "seq_comment":"unknown",
19
+ "collection_date":"unknown" } #'GBSeq_update-date': '25-OCT-2023', 'GBSeq_create-date'
20
+ gb_seq = None
21
+ # Validate record structure: It should be a list with at least one element (a dict)
22
+ if isinstance(record, list) and len(record) > 0:
23
+ if isinstance(record[0], dict):
24
+ gb_seq = record[0]
25
+ else:
26
+ print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
27
+ # extract collection date
28
+ if "GBSeq_create-date" in gb_seq and outputs["collection_date"]=="unknown":
29
+ outputs["collection_date"] = gb_seq["GBSeq_create-date"]
30
+ else:
31
+ if "GBSeq_update-date" in gb_seq and outputs["collection_date"]=="unknown":
32
+ outputs["collection_date"] = gb_seq["GBSeq_update-date"]
33
+ # extract definition
34
+ if "GBSeq_definition" in gb_seq and outputs["definition"]=="unknown":
35
+ outputs["definition"] = gb_seq["GBSeq_definition"]
36
+ # extract related-reference things
37
+ if "GBSeq_references" in gb_seq:
38
+ for ref in gb_seq["GBSeq_references"]:
39
+ # extract authors
40
+ if "GBReference_authors" in ref and outputs["authors"]=="unknown":
41
+ outputs["authors"] = "and ".join(ref["GBReference_authors"])
42
+ # extract title
43
+ if "GBReference_title" in ref and outputs["title"]=="unknown":
44
+ outputs["title"] = ref["GBReference_title"]
45
+ # extract submitted journal
46
+ if 'GBReference_journal' in ref and outputs["institution"]=="unknown":
47
+ outputs["institution"] = ref['GBReference_journal']
48
+ # extract seq_comment
49
+ if 'GBSeq_comment'in gb_seq and outputs["seq_comment"]=="unknown":
50
+ outputs["seq_comment"] = gb_seq["GBSeq_comment"]
51
+ # extract isolate
52
+ if "GBSeq_feature-table" in gb_seq:
53
+ if 'GBFeature_quals' in gb_seq["GBSeq_feature-table"][0]:
54
+ for ref in gb_seq["GBSeq_feature-table"][0]["GBFeature_quals"]:
55
+ if ref['GBQualifier_name'] == "isolate" and outputs["isolate"]=="unknown":
56
+ outputs["isolate"] = ref["GBQualifier_value"]
57
+ else:
58
+ print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
59
+
60
+ # If gb_seq is still None, return defaults
61
+ if gb_seq is None:
62
+ return {"authors":"unknown",
63
+ "institution":"unknown",
64
+ "isolate":"unknown",
65
+ "definition":"unknown",
66
+ "title":"unknown",
67
+ "seq_comment":"unknown",
68
+ "collection_date":"unknown" }
69
+ return outputs
70
+ except:
71
+ print("error in fetching ncbi data")
72
+ return {"authors":"unknown",
73
+ "institution":"unknown",
74
+ "isolate":"unknown",
75
+ "definition":"unknown",
76
+ "title":"unknown",
77
+ "seq_comment":"unknown",
78
+ "collection_date":"unknown" }
79
+ # Fallback if NCBI crashes or cannot find the accession on NCBI
80
+ def google_accession_search(accession_id):
81
+ """
82
+ Search for metadata by accession ID using Google Custom Search.
83
+ Falls back to known biological databases and archives.
84
+ """
85
+ queries = [
86
+ f"{accession_id}",
87
+ f"{accession_id} site:ncbi.nlm.nih.gov",
88
+ f"{accession_id} site:pubmed.ncbi.nlm.nih.gov",
89
+ f"{accession_id} site:europepmc.org",
90
+ f"{accession_id} site:researchgate.net",
91
+ f"{accession_id} mtDNA",
92
+ f"{accession_id} mitochondrial DNA"
93
+ ]
94
+
95
+ links = []
96
+ for query in queries:
97
+ search_results = mtdna_classifier.search_google_custom(query, 2)
98
+ for link in search_results:
99
+ if link not in links:
100
+ links.append(link)
101
+ return links
102
+
103
+ # Method 1: Smarter Google
104
+ def smart_google_queries(metadata: dict):
105
+ queries = []
106
+
107
+ # Extract useful fields
108
+ isolate = metadata.get("isolate")
109
+ author = metadata.get("authors")
110
+ institution = metadata.get("institution")
111
+ title = metadata.get("title")
112
+ combined = []
113
+ # Construct queries
114
+ if isolate and isolate!="unknown":
115
+ queries.append(f'"{isolate}" mitochondrial DNA')
116
+ queries.append(f'"{isolate}" site:ncbi.nlm.nih.gov')
117
+
118
+ if author and author!="unknown":
119
+ try:
120
+ author_name = author.split(',')[0] # Use last name only
121
+ except:
122
+ author_name = author
123
+ queries.append(f'"{author_name}" mitochondrial DNA')
124
+ queries.append(f'"{author_name}" mtDNA site:researchgate.net')
125
+
126
+ if institution and institution!="unknown":
127
+ try:
128
+ short_inst = institution.split(',')[0] # Take first part of institution
129
+ except:
130
+ short_inst = institution
131
+ queries.append(f'"{short_inst}" mtDNA sequence')
132
+ queries.append(f'"{short_inst}" isolate site:nature.com')
133
+ if title and title!='unknown':
134
+ if title!="Direct Submission":
135
+ queries.append(title)
136
+ return queries
137
+
138
+ def filter_links_by_metadata(search_results, saveLinkFolder, accession=None):
139
+ TRUSTED_DOMAINS = [
140
+ "ncbi.nlm.nih.gov",
141
+ "pubmed.ncbi.nlm.nih.gov",
142
+ "pmc.ncbi.nlm.nih.gov",
143
+ "biorxiv.org",
144
+ "researchgate.net",
145
+ "nature.com",
146
+ "sciencedirect.com"
147
+ ]
148
+ def is_trusted_link(link):
149
+ for domain in TRUSTED_DOMAINS:
150
+ if domain in link:
151
+ return True
152
+ return False
153
+ def is_relevant_title_snippet(link, saveLinkFolder, accession=None):
154
+ output = []
155
+ keywords = ["mtDNA", "mitochondrial", "accession", "isolate", "Homo sapiens", "sequence"]
156
+ if accession:
157
+ keywords = [accession] + keywords
158
+ title_snippet = link.lower()
159
+ print("save link folder inside this filter function: ", saveLinkFolder)
160
+ article_text = data_preprocess.extract_text(link,saveLinkFolder)
161
+ print("article text")
162
+ print(article_text)
163
+ if link.split(".")[-1].lower():
164
+ if link.split(".")[-1].lower() != "pdf" and link.split(".")[-1].lower() not in "docx" and link.split(".")[-1].lower() not in "xlxs":
165
+ html = extractHTML.HTML("",link)
166
+ jsonSM = html.getSupMaterial()
167
+ if jsonSM: output += sum((jsonSM[key] for key in jsonSM),[])
168
+ for keyword in keywords:
169
+ if keyword.lower() in article_text.lower():
170
+ if link not in output:
171
+ output.append(link)
172
+ print("link and keyword: ", link, keyword)
173
+ return output
174
+ if keyword.lower() in title_snippet.lower():
175
+ if link not in output:
176
+ output.append(link)
177
+ print("link and keyword: ", link, keyword)
178
+ return output
179
+ return output
180
+
181
+ filtered = []
182
+ if len(search_results) > 0:
183
+ for link in search_results:
184
+ if is_trusted_link(link):
185
+ if link not in filtered:
186
+ filtered.append(link)
187
+ else:
188
+ output_link = is_relevant_title_snippet(link,saveLinkFolder, accession)
189
+ filtered += output_link
190
+ return filtered
191
+
192
+ def smart_google_search(metadata):
193
+ queries = smart_google_queries(metadata)
194
+ links = []
195
+ for q in queries:
196
+ #print("\n🔍 Query:", q)
197
+ results = mtdna_classifier.search_google_custom(q,2)
198
+ for link in results:
199
+ #print(f"- {link}")
200
+ if link not in links:
201
+ links.append(link)
202
+ #filter_links = filter_links_by_metadata(links)
203
+ return links
204
+ # Method 2: Prompt LLM better or better ai search api with all
205
  # the total information from even ncbi and all search
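For reference, the new fallback flow is: fetch accession metadata from NCBI, build Google queries from that metadata, and fall back to a raw accession search when NCBI returns nothing useful. A usage sketch, assuming the accession id is purely illustrative and mtdna_classifier.search_google_custom is configured with working Google Custom Search credentials:

import smart_fallback

accession = "KU131308"  # hypothetical accession id, for illustration only
meta = smart_fallback.fetch_ncbi(accession)        # NCBI metadata; fields default to "unknown"
links = smart_fallback.smart_google_search(meta)   # metadata-driven Google queries
if not links:
    # NCBI metadata was all "unknown" or the queries found nothing: search the raw accession id
    links = smart_fallback.google_accession_search(accession)
print(meta["isolate"], meta["collection_date"], len(links))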
standardize_location.py CHANGED
@@ -1,83 +1,83 @@
1
- import requests
2
- import re
3
- import os
4
- # Normalize input
5
- def normalize_key(text):
6
- return re.sub(r"[^a-z0-9]", "", text.strip().lower())
7
-
8
- # Search for city/place (normal flow)
9
- def get_country_from_geonames(city_name):
10
- url = os.environ["URL_SEARCHJSON"]
11
- username = os.environ["USERNAME_GEO"]
12
- print("geoname: ", cityname)
13
- params = {
14
- "q": city_name,
15
- "maxRows": 1,
16
- "username": username
17
- }
18
- try:
19
- r = requests.get(url, params=params, timeout=5)
20
- data = r.json()
21
- if data.get("geonames"):
22
- return data["geonames"][0]["countryName"]
23
- except Exception as e:
24
- print("GeoNames searchJSON error:", e)
25
- return None
26
-
27
- # Search for country info using alpha-2/3 codes or name
28
- def get_country_from_countryinfo(input_code):
29
- url = os.environ["URL_COUNTRYJSON"]
30
- username = os.environ["USERNAME_GEO"]
31
- print("countryINFO: ", input_code)
32
- params = {
33
- "username": username
34
- }
35
- try:
36
- r = requests.get(url, params=params, timeout=5)
37
- data = r.json()
38
- if data.get("geonames"):
39
- input_code = input_code.strip().upper()
40
- for country in data["geonames"]:
41
- # Match against country name, country code (alpha-2), iso alpha-3
42
- if input_code in [
43
- country.get("countryName", "").upper(),
44
- country.get("countryCode", "").upper(),
45
- country.get("isoAlpha3", "").upper()
46
- ]:
47
- return country["countryName"]
48
- except Exception as e:
49
- print("GeoNames countryInfoJSON error:", e)
50
- return None
51
-
52
- # Combined smart lookup
53
- def smart_country_lookup(user_input):
54
- raw_input = user_input.strip()
55
- normalized = re.sub(r"[^a-zA-Z0-9]", "", user_input).upper() # normalize for codes (no strip spaces!)
56
- print(raw_input, normalized)
57
- # Special case: if user writes "UK: London" → split and take main country part
58
- if ":" in raw_input:
59
- raw_input = raw_input.split(":")[0].strip() # only take "UK"
60
- # First try as country code (if 2-3 letters or common abbreviation)
61
- if len(normalized) <= 3:
62
- if normalized.upper() in ["UK","U.K","U.K."]:
63
- country = get_country_from_geonames(normalized.upper())
64
- print("get_country_from_geonames(normalized.upper()) ", country)
65
- if country:
66
- return country
67
- else:
68
- country = get_country_from_countryinfo(raw_input)
69
- print("get_country_from_countryinfo(raw_input) ", country)
70
- if country:
71
- return country
72
- print(raw_input)
73
- country = get_country_from_countryinfo(raw_input) # try full names
74
- print("get_country_from_countryinfo(raw_input) ", country)
75
- if country:
76
- return country
77
- # Otherwise, treat as city/place
78
- country = get_country_from_geonames(raw_input)
79
- print("get_country_from_geonames(raw_input) ", country)
80
- if country:
81
- return country
82
-
83
  return "Not found"
 
1
+ import requests
2
+ import re
3
+ import os
4
+ # Normalize input
5
+ def normalize_key(text):
6
+ return re.sub(r"[^a-z0-9]", "", text.strip().lower())
7
+
8
+ # Search for city/place (normal flow)
9
+ def get_country_from_geonames(city_name):
10
+ url = os.environ["URL_SEARCHJSON"]
11
+ username = os.environ["USERNAME_GEO"]
12
+ print("geoname: ", cityname)
13
+ params = {
14
+ "q": city_name,
15
+ "maxRows": 1,
16
+ "username": username
17
+ }
18
+ try:
19
+ r = requests.get(url, params=params, timeout=5)
20
+ data = r.json()
21
+ if data.get("geonames"):
22
+ return data["geonames"][0]["countryName"]
23
+ except Exception as e:
24
+ print("GeoNames searchJSON error:", e)
25
+ return None
26
+
27
+ # Search for country info using alpha-2/3 codes or name
28
+ def get_country_from_countryinfo(input_code):
29
+ url = os.environ["URL_COUNTRYJSON"]
30
+ username = os.environ["USERNAME_GEO"]
31
+ print("countryINFO: ", input_code)
32
+ params = {
33
+ "username": username
34
+ }
35
+ try:
36
+ r = requests.get(url, params=params, timeout=5)
37
+ data = r.json()
38
+ if data.get("geonames"):
39
+ input_code = input_code.strip().upper()
40
+ for country in data["geonames"]:
41
+ # Match against country name, country code (alpha-2), iso alpha-3
42
+ if input_code in [
43
+ country.get("countryName", "").upper(),
44
+ country.get("countryCode", "").upper(),
45
+ country.get("isoAlpha3", "").upper()
46
+ ]:
47
+ return country["countryName"]
48
+ except Exception as e:
49
+ print("GeoNames countryInfoJSON error:", e)
50
+ return None
51
+
52
+ # Combined smart lookup
53
+ def smart_country_lookup(user_input):
54
+ raw_input = user_input.strip()
55
+ normalized = re.sub(r"[^a-zA-Z0-9]", "", user_input).upper() # normalize for codes (no strip spaces!)
56
+ print(raw_input, normalized)
57
+ # Special case: if user writes "UK: London" → split and take main country part
58
+ if ":" in raw_input:
59
+ raw_input = raw_input.split(":")[0].strip() # only take "UK"
60
+ # First try as country code (if 2-3 letters or common abbreviation)
61
+ if len(normalized) <= 3:
62
+ if normalized.upper() in ["UK","U.K","U.K."]:
63
+ country = get_country_from_geonames(normalized.upper())
64
+ print("get_country_from_geonames(normalized.upper()) ", country)
65
+ if country:
66
+ return country
67
+ else:
68
+ country = get_country_from_countryinfo(raw_input)
69
+ print("get_country_from_countryinfo(raw_input) ", country)
70
+ if country:
71
+ return country
72
+ print(raw_input)
73
+ country = get_country_from_countryinfo(raw_input) # try full names
74
+ print("get_country_from_countryinfo(raw_input) ", country)
75
+ if country:
76
+ return country
77
+ # Otherwise, treat as city/place
78
+ country = get_country_from_geonames(raw_input)
79
+ print("get_country_from_geonames(raw_input) ", country)
80
+ if country:
81
+ return country
82
+
83
  return "Not found"
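A usage sketch for the lookup above. The module reads its GeoNames endpoints and username from environment variables, so they must be set before calling it; the URLs below are the public GeoNames searchJSON / countryInfoJSON services and the username is a placeholder, both assumptions rather than values from this repo:

import os
import standardize_location

# Assumed public GeoNames endpoints; the repo keeps the real values in secrets.
os.environ.setdefault("URL_SEARCHJSON", "http://api.geonames.org/searchJSON")
os.environ.setdefault("URL_COUNTRYJSON", "http://api.geonames.org/countryInfoJSON")
os.environ.setdefault("USERNAME_GEO", "your_geonames_username")  # placeholder account name

print(standardize_location.smart_country_lookup("UK: London"))  # expected: "United Kingdom"
print(standardize_location.smart_country_lookup("Hanoi"))       # city resolved via searchJSON
print(standardize_location.smart_country_lookup("VNM"))         # ISO alpha-3 code via countryInfoJSON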