{
"cells": [
{
"cell_type": "markdown",
"id": "e66bbb77-71f5-4d80-b766-f67144ea7a93",
"metadata": {},
"source": [
"# Data Records\n",
"\n",
"## This notebook generates the data_records.json file where each entry in the resulting dictionary follows the form {filename: num_records} for every dataset we will use during training"
]
},
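{
"cell_type": "markdown",
"id": "added-usage-sketch-md",
"metadata": {},
"source": [
"*Illustrative sketch (added for clarity, not part of the original pipeline):* once the cells below have written `data_records.json`, a downstream training script can load it back as a plain `{filename: num_records}` dictionary. The snippet guards on the file's existence so the notebook still runs top to bottom before the file has been generated."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-usage-sketch-code",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# illustrative sketch: load the generated data_records.json back into a dictionary\n",
"import json\n",
"import os\n",
"\n",
"if os.path.exists('data_records.json'):\n",
"    with open('data_records.json') as filein:\n",
"        record_counts = json.load(filein)\n",
"    # each entry has the form {filename: num_records}\n",
"    print(len(record_counts), 'datasets,', sum(record_counts.values()), 'total training records')"
]
},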
{
"cell_type": "code",
"execution_count": 39,
"id": "74ad6613-44ff-435e-8550-df993e915677",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# import relevant libraries\n",
"import os\n",
"import boto3\n",
"import json\n",
"from smart_open import open"
]
},
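{
"cell_type": "markdown",
"id": "added-multi-positive-note-md",
"metadata": {},
"source": [
"*Illustrative sketch (added for clarity):* in the cell below, `S2ORC_citations_abstracts.json.gz` and `amazon-qa.json.gz` are counted line by line because every line pairs one query with several positive responses, so each line contributes `len(data['pos'])` records rather than one. The hypothetical line here only demonstrates that counting rule; apart from the `pos` field, the field names and values are made up."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-multi-positive-note-code",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# hypothetical example of a single line from a multi-positive dataset:\n",
"# one query paired with three positive responses\n",
"example_line = '{\"query\": \"example question\", \"pos\": [\"answer one\", \"answer two\", \"answer three\"]}'\n",
"\n",
"example = json.loads(example_line)\n",
"# this single line would count as len(example['pos']) == 3 training records\n",
"print(len(example['pos']))"
]
},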
{
"cell_type": "code",
"execution_count": null,
"id": "e2d53761-da0e-44f4-8a3e-1285bf810b03",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"s3 = boto3.resource('s3')\n",
"my_bucket = s3.Bucket('lodestone-rnd')\n",
"\n",
"# collect all filenames from the data/ directory of the lodestone-rnd S3 bucket\n",
"files = [\"\"]*((621+12+9+36)+1)\n",
"for i, object_summary in enumerate(my_bucket.objects.filter(Prefix=\"data/\")):\n",
" files[i] = object_summary.key[5:]\n",
"files = files[1:]\n",
"files = [file for file in files if file != 'cnn_dailymail_splitted.json.gz']\n",
"\n",
"s3_client = boto3.client(\"s3\")\n",
"\n",
"# for each training dataset, store the number of records in a dictionary with the following form {filename: num_records}\n",
"data_lengths = {}\n",
"for file in files:\n",
" source_uri = f's3://lodestone-rnd/data/{file}'\n",
" # S2ORC_citations_abstracts.json.gz and amazon-qa.json.gz must be handled differently since each line in their training\n",
" # data is split into multiple records due to the fact that each query has multiple positive pair responses\n",
" if file in ['S2ORC_citations_abstracts.json.gz','amazon-qa.json.gz']:\n",
" length = 0\n",
" for json_line in open(source_uri, transport_params={\"client\": s3_client}):\n",
" data = json.loads(json_line.strip())\n",
" length += len(data['pos'])\n",
" else:\n",
" length = int(os.popen(f'aws s3 cp {source_uri} - | zcat | wc -l').read().rstrip())\n",
" data_lengths[f'{file}'] = length\n",
" \n",
"# write the resulting dictionary to a .json file for future use during training\n",
"with open('data_records.json', 'w') as fileout:\n",
" json.dump(data_lengths, fileout)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_pytorch_p310",
"language": "python",
"name": "conda_pytorch_p310"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}