Spaces:
Running
Running
/**
 * HuggingFace Dataset Viewer API wrapper.
 * Handles fetching data from the datasets-server API with caching and error handling.
 */
class DatasetAPI {
  constructor() {
    this.baseURL = 'https://datasets-server.huggingface.co';
    this.cache = new Map();
    this.cacheExpiry = 45 * 60 * 1000; // 45 minutes (conservative for signed URLs)
    this.rowsPerFetch = 100; // API maximum
  }

  /**
   * Build a datasets-server endpoint URL with safely encoded query parameters.
   * @param {string} path - Endpoint path, e.g. '/rows'.
   * @param {Object<string, string|number>} params - Query parameters.
   * @returns {string} Absolute URL.
   */
  buildUrl(path, params) {
    const url = new URL(path, this.baseURL);
    for (const [key, value] of Object.entries(params)) {
      url.searchParams.set(key, value);
    }
    return url.toString();
  }

  /**
   * Check if a dataset is valid and has the viewer enabled.
   * @param {string} datasetId - e.g. 'user/dataset'.
   * @returns {Promise<true>} Resolves true when the viewer is available.
   * @throws {Error} When the request fails or the viewer is disabled.
   */
  async validateDataset(datasetId) {
    try {
      const response = await fetch(this.buildUrl('/is-valid', { dataset: datasetId }));
      if (!response.ok) {
        throw new Error(`Failed to validate dataset: ${response.statusText}`);
      }
      const data = await response.json();
      if (!data.viewer) {
        throw new Error('Dataset viewer is not available for this dataset');
      }
      return true;
    } catch (error) {
      throw new Error(`Dataset validation failed: ${error.message}`);
    }
  }

  /**
   * Get dataset info including splits and configs.
   * @param {string} datasetId
   * @returns {Promise<{configs: string[], splits: string[], defaultConfig: string, defaultSplit: string, raw: Object}>}
   * @throws {Error} When the splits endpoint request fails.
   */
  async getDatasetInfo(datasetId) {
    const cacheKey = `info_${datasetId}`;
    const cached = this.getFromCache(cacheKey);
    // Explicit null check: a falsy-but-valid cached value must still count as a hit.
    if (cached !== null) return cached;
    try {
      const response = await fetch(this.buildUrl('/splits', { dataset: datasetId }));
      if (!response.ok) {
        throw new Error(`Failed to get dataset info: ${response.statusText}`);
      }
      const data = await response.json();
      // Extract the default config and split; prefer 'train' when present.
      const defaultConfig = data.splits[0]?.config || 'default';
      const defaultSplit = data.splits.find(s => s.split === 'train')?.split || data.splits[0]?.split || 'train';
      const info = {
        configs: [...new Set(data.splits.map(s => s.config))],
        splits: [...new Set(data.splits.map(s => s.split))],
        defaultConfig,
        defaultSplit,
        raw: data
      };
      this.setCache(cacheKey, info);
      return info;
    } catch (error) {
      throw new Error(`Failed to get dataset info: ${error.message}`);
    }
  }

  /**
   * Get the total number of rows in a dataset split.
   * Tries the /size endpoint first, then falls back to /first-rows.
   * @param {string} datasetId
   * @param {string} config
   * @param {string} split
   * @returns {Promise<number|null>} Row count, or null if it cannot be determined.
   */
  async getTotalRows(datasetId, config, split) {
    const cacheKey = `size_${datasetId}_${config}_${split}`;
    const cached = this.getFromCache(cacheKey);
    // A cached size of 0 is valid — only a null result means "not cached".
    if (cached !== null) return cached;
    try {
      // First try to get from the size endpoint
      const sizeResponse = await fetch(
        this.buildUrl('/size', { dataset: datasetId, config, split })
      );
      if (sizeResponse.ok) {
        const sizeData = await sizeResponse.json();
        // The API returns num_rows in size.config or size.splits[0].
        // Use ?? so a legitimate num_rows of 0 is not skipped.
        const size = sizeData.size?.config?.num_rows ??
          sizeData.size?.splits?.[0]?.num_rows ??
          0;
        this.setCache(cacheKey, size);
        return size;
      }
      // Fallback: get first rows and check num_rows_total
      const rowsResponse = await fetch(
        this.buildUrl('/first-rows', { dataset: datasetId, config, split })
      );
      if (!rowsResponse.ok) {
        throw new Error('Unable to determine dataset size');
      }
      const rowsData = await rowsResponse.json();
      const size = rowsData.num_rows_total ?? rowsData.rows?.length ?? 0;
      this.setCache(cacheKey, size);
      return size;
    } catch (error) {
      // Best-effort: callers treat null as "size unknown" rather than failing hard.
      console.warn('Failed to get total rows:', error);
      return null;
    }
  }

  /**
   * Fetch a page of rows from the dataset.
   * @param {string} datasetId
   * @param {string} config
   * @param {string} split
   * @param {number} offset - Zero-based index of the first row to fetch.
   * @param {number} [length=this.rowsPerFetch] - Number of rows to fetch (API max 100).
   * @returns {Promise<{rows: Array, features: Array, columns: Object, numRowsTotal: number, partial: boolean}>}
   * @throws {Error} On HTTP failure, including a specific message for 403 (private/gated).
   */
  async fetchRows(datasetId, config, split, offset, length = this.rowsPerFetch) {
    const cacheKey = `rows_${datasetId}_${config}_${split}_${offset}_${length}`;
    const cached = this.getFromCache(cacheKey);
    if (cached !== null) return cached;
    try {
      const response = await fetch(
        this.buildUrl('/rows', { dataset: datasetId, config, split, offset, length })
      );
      if (!response.ok) {
        if (response.status === 403) {
          throw new Error('Access denied. This dataset may be private or gated.');
        }
        throw new Error(`Failed to fetch rows: ${response.statusText}`);
      }
      const data = await response.json();
      // Extract column information from features, falling back to the first row.
      const columns = this.detectColumns(data.features, data.rows[0]?.row);
      const result = {
        rows: data.rows,
        features: data.features,
        columns,
        numRowsTotal: data.num_rows_total,
        partial: data.partial || false
      };
      this.setCache(cacheKey, result);
      return result;
    } catch (error) {
      throw new Error(`Failed to fetch rows: ${error.message}`);
    }
  }

  /**
   * Get a single row by index with smart batching: rows are fetched in
   * rowsPerFetch-aligned batches so neighboring lookups hit the cache.
   * @param {string} datasetId
   * @param {string} config
   * @param {string} split
   * @param {number} index - Absolute row index.
   * @returns {Promise<{row: Object, columns: Object, numRowsTotal: number}>}
   * @throws {Error} When the index is outside the fetched batch.
   */
  async getRow(datasetId, config, split, index) {
    // Calculate which batch this index falls into
    const batchStart = Math.floor(index / this.rowsPerFetch) * this.rowsPerFetch;
    const batchData = await this.fetchRows(datasetId, config, split, batchStart, this.rowsPerFetch);
    const localIndex = index - batchStart;
    if (localIndex >= 0 && localIndex < batchData.rows.length) {
      return {
        row: batchData.rows[localIndex].row,
        columns: batchData.columns,
        numRowsTotal: batchData.numRowsTotal
      };
    }
    throw new Error(`Row ${index} not found`);
  }

  /**
   * Detect column names for image and text data.
   * Checks declared features first, then falls back to inspecting a sample row.
   * @param {Array<{name: string, type: Object}>} features - Feature descriptors from the API.
   * @param {Object} [sampleRow] - A raw row object used as a fallback for detection.
   * @returns {{image: string|null, originalText: string|null, improvedText: string|null, inferenceInfo: string|null}}
   */
  detectColumns(features, sampleRow) {
    let imageColumn = null;
    let originalTextColumn = null;
    let improvedTextColumn = null;
    let inferenceInfoColumn = null;
    // Try to detect from features first
    for (const feature of features || []) {
      const name = feature.name;
      const type = feature.type;
      // Detect image column
      if (type._type === 'Image' || type.dtype === 'image' || type.feature?._type === 'Image') {
        imageColumn = name;
      }
      // Detect text columns based on common patterns
      if (!originalTextColumn && ['text', 'ocr', 'original_text', 'original', 'ground_truth'].includes(name)) {
        originalTextColumn = name;
      }
      if (!improvedTextColumn && ['markdown', 'new_ocr', 'corrected_text', 'improved', 'vlm_ocr', 'corrected', 'rolmocr_text'].includes(name)) {
        improvedTextColumn = name;
      }
      // Detect inference info column
      if (name === 'inference_info') {
        inferenceInfoColumn = name;
      }
    }
    // Fallback: detect from sample row
    if (sampleRow) {
      const keys = Object.keys(sampleRow);
      if (!imageColumn) {
        for (const key of keys) {
          // The rows API represents images as objects with src/height/width.
          if (sampleRow[key]?.src && sampleRow[key]?.height !== undefined) {
            imageColumn = key;
            break;
          }
        }
      }
      // Additional text column detection from row data
      if (!originalTextColumn) {
        const candidates = ['text', 'ocr', 'original_text', 'original'];
        originalTextColumn = keys.find(k => candidates.includes(k)) || null;
      }
      if (!improvedTextColumn) {
        const candidates = ['markdown', 'new_ocr', 'corrected_text', 'improved', 'rolmocr_text'];
        improvedTextColumn = keys.find(k => candidates.includes(k)) || null;
      }
      // Check for inference info in sample row
      if (!inferenceInfoColumn && keys.includes('inference_info')) {
        inferenceInfoColumn = 'inference_info';
      }
    }
    return {
      image: imageColumn,
      originalText: originalTextColumn,
      improvedText: improvedTextColumn,
      inferenceInfo: inferenceInfoColumn
    };
  }

  /**
   * Refresh an expired (signed) image URL by evicting the row's batch
   * from the cache and re-fetching it.
   * @param {string} datasetId
   * @param {string} config
   * @param {string} split
   * @param {number} index - Absolute row index.
   * @returns {Promise<{row: Object, columns: Object, numRowsTotal: number}>}
   */
  async refreshImageUrl(datasetId, config, split, index) {
    // Clear cache for this specific row batch
    const batchStart = Math.floor(index / this.rowsPerFetch) * this.rowsPerFetch;
    const cacheKey = `rows_${datasetId}_${config}_${split}_${batchStart}_${this.rowsPerFetch}`;
    this.cache.delete(cacheKey);
    // Re-fetch the row
    return await this.getRow(datasetId, config, split, index);
  }

  /**
   * Read a value from the cache.
   * @param {string} key
   * @returns {*} The cached value, or null when absent or expired.
   *   Note: callers must compare against null, not truthiness, since 0 is a valid value.
   */
  getFromCache(key) {
    const cached = this.cache.get(key);
    if (!cached) return null;
    if (Date.now() - cached.timestamp > this.cacheExpiry) {
      this.cache.delete(key);
      return null;
    }
    return cached.data;
  }

  /**
   * Store a value in the cache with the current timestamp.
   * @param {string} key
   * @param {*} data
   */
  setCache(key, data) {
    this.cache.set(key, {
      data,
      timestamp: Date.now()
    });
  }

  /** Drop all cached entries. */
  clearCache() {
    this.cache.clear();
  }

  /**
   * Parse inference info safely, tolerating the several shapes datasets use
   * (object, JSON string, or array — first element wins for arrays).
   * @param {Object|string|Array|null} inferenceInfoData
   * @returns {Object|null} Parsed object, or null when absent or unparseable.
   */
  parseInferenceInfo(inferenceInfoData) {
    if (!inferenceInfoData) return null;
    try {
      // Handle if it's already an object (some datasets might store it as object)
      if (typeof inferenceInfoData === 'object' && !Array.isArray(inferenceInfoData)) {
        return inferenceInfoData;
      }
      // Handle if it's a JSON string
      if (typeof inferenceInfoData === 'string') {
        const parsed = JSON.parse(inferenceInfoData);
        // If it's an array, take the first item
        if (Array.isArray(parsed) && parsed.length > 0) {
          return parsed[0];
        }
        return parsed;
      }
      // Handle if it's already an array
      if (Array.isArray(inferenceInfoData) && inferenceInfoData.length > 0) {
        return inferenceInfoData[0];
      }
      return null;
    } catch (error) {
      console.warn('Failed to parse inference info:', error);
      return null;
    }
  }
}
// Export for use in other scripts. Guarded so loading this file in a
// non-browser context (worker, test runner) does not throw a ReferenceError.
if (typeof window !== 'undefined') {
  window.DatasetAPI = DatasetAPI;
}