File size: 4,080 Bytes
7e4123a
 
 
 
046ea77
7e4123a
 
 
 
046ea77
7e4123a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1be7f01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e4123a
1be7f01
 
 
 
 
 
 
 
 
 
7e4123a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
046ea77
 
 
 
 
1be7f01
 
 
 
046ea77
 
 
 
 
15c9d8a
046ea77
 
7e4123a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import copy
import datasets
import pandas as pd
from datasets import Dataset
from collections import defaultdict

from datetime import datetime, timedelta
from background import process_arxiv_ids
from utils import create_hf_hub
from apscheduler.schedulers.background import BackgroundScheduler

def _count_nans(row):
    count = 0

    for _, (k, v) in enumerate(row.items()):
        if v is None:
            count = count + 1

    return count

def _initialize_requested_arxiv_ids(request_ds):
    requested_arxiv_ids = []

    for request_d in request_ds['train']:
        arxiv_ids = request_d['Requested arXiv IDs']
        requested_arxiv_ids = requested_arxiv_ids + arxiv_ids
    
    requested_arxiv_ids_df = pd.DataFrame({'Requested arXiv IDs': requested_arxiv_ids})
    return requested_arxiv_ids_df

def _initialize_paper_info(source_ds):
    """Index the papers in ``source_ds["train"]`` by title, date and arXiv ID.

    Rows are grouped by ``target_date``. Within one date, rows sharing a
    title are de-duplicated: the row with fewer ``None`` fields is kept, and
    on a tie the later row replaces the earlier one.

    Args:
        source_ds: Mapping whose ``"train"`` split is a sized iterable of row
            dicts. Each row needs ``"target_date"`` (anything with
            ``strftime``, presumably ``datetime`` -- confirm), ``"title"``
            and ``"arxiv_id"``.

    Returns:
        ``(titles, date_dict, arxivid2data)``:

        * ``titles`` -- ``"[<arxiv_id>] <title>"`` strings, one per distinct
          title (when a title appears on several dates, the last one wins).
        * ``date_dict`` -- nested ``year -> month -> day -> [paper, ...]``
          mapping with zero-padded string keys.
        * ``arxivid2data`` -- ``arxiv_id -> {"idx": 0, "paper": paper}``.

        If the split has at most one row, ``([], {}, {})`` is returned.
    """
    train = source_ds["train"]

    # A single-row split is treated as empty; presumably this skips the
    # one-row placeholder dataset seeded when the repo is first created.
    # NOTE(review): it also skips a dataset holding exactly one real paper.
    if len(train) <= 1:
        return [], {}, {}

    date2qna = _group_papers_by_date(train)

    title2qna = {}
    date_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    arxivid2data = {}
    for date, papers in date2qna.items():
        year, month, day = date.split("-")
        for paper in papers:
            title2qna[paper["title"]] = paper
            # NOTE(review): the original counter was never advanced, so
            # "idx" has always been 0; kept as-is in case callers rely on it.
            arxivid2data[paper["arxiv_id"]] = {"idx": 0, "paper": paper}
            date_dict[year][month][day].append(paper)

    titles = [f"[{paper['arxiv_id']}] {title}" for title, paper in title2qna.items()]
    return titles, date_dict, arxivid2data


def _group_papers_by_date(rows):
    """Group rows by ``YYYY-MM-DD`` target date, keeping one row per title per date."""
    date2qna = {}
    for data in rows:
        date = data["target_date"].strftime("%Y-%m-%d")
        papers = date2qna.setdefault(date, [])

        existing = next((p for p in papers if p["title"] == data["title"]), None)
        if existing is not None:
            if _count_nans(existing) < _count_nans(data):
                # The stored row is more complete; drop the incoming one
                # instead of keeping both (the old code always appended).
                continue
            papers.remove(existing)
        papers.append(data)
    return date2qna

def initialize_data(source_data_repo_id, request_data_repo_id, empty_src_dataset):
    """Load both Hub datasets and build the in-memory lookup structures.

    The results are also stored in the module globals ``date_dict``,
    ``arxivid2data`` and ``requested_arxiv_ids_df``.

    Args:
        source_data_repo_id: Repo id of the source (paper) dataset.
        request_data_repo_id: Repo id of the requested-IDs dataset.
        empty_src_dataset: Currently unused; kept for caller compatibility.

    Returns:
        ``(titles, date_dict, requested_arxiv_ids_df, arxivid2data)``.
    """
    global date_dict, arxivid2data, requested_arxiv_ids_df

    src = datasets.load_dataset(source_data_repo_id)
    req = datasets.load_dataset(request_data_repo_id)

    paper_titles, date_dict, arxivid2data = _initialize_paper_info(src)
    requested_arxiv_ids_df = _initialize_requested_arxiv_ids(req)

    return paper_titles, date_dict, requested_arxiv_ids_df, arxivid2data

def update_dataframe(request_data_repo_id):
    """Re-download the request dataset and return its IDs as a one-column DataFrame."""
    return _initialize_requested_arxiv_ids(datasets.load_dataset(request_data_repo_id))

def initialize_repos(
    source_data_repo_id, request_data_repo_id, hf_token
):
    """Create the source and request dataset repos on the Hub if needed.

    ``create_hf_hub`` is called for each repo id in turn (source first).
    When it returns ``False`` the repo is reported as already existing and
    left untouched; otherwise a one-row seed dataset is pushed to it:

    * source repo: ``{"title": ["dummy"]}``
    * request repo: ``{"Requested arXiv IDs": [["top"]]}``

    Args:
        source_data_repo_id: Hub repo id of the source dataset.
        request_data_repo_id: Hub repo id of the requested-IDs dataset.
        hf_token: Hugging Face token passed to every Hub call.
    """
    def seed_source():
        return Dataset.from_dict({"title": ["dummy"]})

    def seed_requests():
        frame = pd.DataFrame(data={"Requested arXiv IDs": [["top"]]})
        return Dataset.from_pandas(frame)

    # Seeds are built lazily, only for repos that were just created.
    for repo_id, build_seed in (
        (source_data_repo_id, seed_source),
        (request_data_repo_id, seed_requests),
    ):
        if create_hf_hub(repo_id, hf_token) is False:
            print(f"{repo_id} repository already exists")
            continue
        build_seed().push_to_hub(repo_id, token=hf_token)

def get_secrets():
    """Read the app's configuration from environment variables.

    Also stores ``gemini_api_key``, ``hf_token``, ``dataset_repo_id`` and
    ``request_arxiv_repo_id`` as module globals. Unset variables come back as
    ``None``, except ``RESTART_TARGET_SPACE_REPO_ID``, which defaults to
    ``"chansung/paper_qa"``.

    Returns:
        ``(gemini_api_key, hf_token, dataset_repo_id, request_arxiv_repo_id,
        restart_repo_id)``.
    """
    global gemini_api_key, hf_token, request_arxiv_repo_id, dataset_repo_id

    env = os.getenv
    gemini_api_key = env("GEMINI_API_KEY")
    hf_token = env("HF_TOKEN")
    dataset_repo_id = env("SOURCE_DATA_REPO_ID")
    request_arxiv_repo_id = env("REQUEST_DATA_REPO_ID")
    restart_repo_id = env("RESTART_TARGET_SPACE_REPO_ID", "chansung/paper_qa")

    secrets = (gemini_api_key, hf_token, dataset_repo_id, request_arxiv_repo_id, restart_repo_id)
    return secrets