Yoad committed on
Commit
f59869c
·
1 Parent(s): e641911

Dataset Preview - Initial commit

Browse files
Files changed (6) hide show
  1. .gitignore +13 -0
  2. .python-version +1 -0
  3. app.py +218 -0
  4. pyproject.toml +13 -0
  5. requirements.txt +4 -0
  6. uv.lock +0 -0
.gitignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python venvs
2
+ .venv
3
+
4
+ # env files
5
+ .env
6
+
7
+ # local streamlit state
8
+ .streamlit/
9
+
10
+ # pycache
11
+ __pycache__/
12
+ *.py[cod]
13
+
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11.9
app.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import math
import os
import random

import altair as alt
import pandas as pd
import streamlit as st
from huggingface_hub import HfApi
9
+
10
# Page-wide Streamlit configuration; must run before any other st.* call.
st.set_page_config(page_title="Knesset Plenums Dataset Preview", layout="wide")

# Resolve the dataset repo id: prefer the Space author's namespace
# (SPACE_AUTHOR_NAME, set by HF Spaces), falling back to REPO_OWNER and
# finally to the hard-coded "ivrit.ai" owner.
fallback_dataset_repo_owner = os.environ.get("REPO_OWNER", "ivrit.ai")
dataset_repo_owner = os.environ.get("SPACE_AUTHOR_NAME", fallback_dataset_repo_owner)
dataset_repo_name = os.environ.get("DATASET_REPO_NAME", "knesset-plenums")
repo_id = f"{dataset_repo_owner}/{dataset_repo_name}"

# Authenticated Hub client; HF_TOKEN must exist in Streamlit secrets
# (st.secrets raises if the key is missing).
hf_api = HfApi(token=st.secrets["HF_TOKEN"])
18
+
19
# Download the dataset manifest (one row per plenum) and load it as a DataFrame.
manifest_file = hf_api.hf_hub_download(repo_id, "manifest.csv", repo_type="dataset")

manifest_df = pd.read_csv(manifest_file)

# Filter samples with duration less than 7200 seconds (2 hours).
filtered_samples = manifest_df[manifest_df["duration"] < 7200].copy()

# Convert duration from seconds to hours for display.
filtered_samples["duration_hours"] = filtered_samples["duration"] / 3600
28
+
29
# Map human-readable dropdown labels to plenum IDs.
# NOTE: two plenums with the same date and rounded duration would collide on
# the label; the later row would win — same behavior as before.
sample_options = {}
for record in filtered_samples.to_dict("records"):
    label = f"{record['plenum_date']} - ({round(record['duration_hours'], 1)} hours)"
    sample_options[label] = str(record["plenum_id"])
37
+
38
# Prefer plenum 81733 as the initial selection when present; otherwise fall
# back to the first available option (or None when there are no samples).
default_sample_id = "81733"
default_option = None
for label, pid in sample_options.items():
    if pid == default_sample_id:
        default_option = label
        break
else:
    if sample_options:
        default_option = next(iter(sample_options))
44
+
45
# Sidebar dropdown for choosing which plenum sample to preview.
# NOTE(review): if sample_options were empty, selectbox would return None and
# the lookup below would raise — assumes the manifest always yields options.
selected_option = st.sidebar.selectbox(
    "Select a plenum sample:",
    options=list(sample_options.keys()),
    index=list(sample_options.keys()).index(default_option) if default_option else 0,
)

# Resolve the selected plenum ID and the in-repo paths of its files.
sample_plenum_id = sample_options[selected_option]
sample_audio_file_repo_path = f"{sample_plenum_id}/audio.m4a"
sample_metadata_file_repo_path = f"{sample_plenum_id}/metadata.json"
sample_aligned_file_repo_path = f"{sample_plenum_id}/transcript.aligned.json"
sample_raw_text_repo_path = f"{sample_plenum_id}/raw.transcript.txt"


# Display the title with the selected Plenum ID.
st.title(f"Knesset Plenum ID: {sample_plenum_id}")
62
+
63
+
64
# Cache downloads so files are only re-fetched when the selected sample changes.
@st.cache_data
def load_sample_data(repo_id, plenum_id):
    """Download the audio, metadata, aligned-transcript and raw-transcript
    files for *plenum_id* from *repo_id* and return their local paths.

    All four repo paths are derived from the arguments so that the
    st.cache_data key (repo_id, plenum_id) fully determines the result.
    """
    audio_path = f"{plenum_id}/audio.m4a"
    metadata_path = f"{plenum_id}/metadata.json"
    transcript_path = f"{plenum_id}/transcript.aligned.json"
    # Fix: derive the raw-transcript path locally instead of reading the
    # module-level sample_raw_text_repo_path global — a global is not part of
    # the cache key, so depending on it could silently serve stale results.
    raw_text_path = f"{plenum_id}/raw.transcript.txt"

    audio_file = hf_api.hf_hub_download(repo_id, audio_path, repo_type="dataset")
    metadata_file = hf_api.hf_hub_download(repo_id, metadata_path, repo_type="dataset")
    transcript_file = hf_api.hf_hub_download(
        repo_id, transcript_path, repo_type="dataset"
    )
    raw_transcript_text_file = hf_api.hf_hub_download(
        repo_id, raw_text_path, repo_type="dataset"
    )

    return audio_file, metadata_file, transcript_file, raw_transcript_text_file
82
+
83
+
84
# Download (or fetch from cache) the files for the selected plenum.
(
    sample_audio_file,
    sample_metadata_file,
    sample_transcript_aligned_file,
    sample_raw_transcript_text_file,
) = load_sample_data(repo_id, sample_plenum_id)

# Parse the metadata file of this sample - to get the list of all segments.
# Fix: open text files with an explicit UTF-8 encoding — the transcripts are
# Hebrew text, and the default open() encoding is locale-dependent.
with open(sample_metadata_file, "r", encoding="utf-8") as f:
    sample_metadata = json.load(f)

# each segment is a dict with the structure:
# {
#     "start": 3527.26,
#     "end": 3531.53,
#     "probability": 0.9309
# },
segments_quality_scores = sample_metadata["per_segment_quality_scores"]
segments_quality_scores_df = pd.DataFrame(segments_quality_scores)
segments_quality_scores_df["segment_id"] = segments_quality_scores_df.index

with open(sample_transcript_aligned_file, "r", encoding="utf-8") as f:
    sample_transcript_aligned = json.load(f)
transcript_segments = sample_transcript_aligned["segments"]

with open(sample_raw_transcript_text_file, "r", encoding="utf-8") as f:
    sample_raw_text = f.read()
112
+
113
# Two-column layout: segment table + player on the left, chart on the right.
col_main, col_aux = st.columns([2, 3])

# Interactive table of per-segment quality scores; selecting a row reruns the
# script and exposes the selection on the returned event object.
event = col_main.dataframe(
    segments_quality_scores_df,
    on_select="rerun",
    hide_index=True,
    selection_mode=["single-row"],
    column_config={
        # Render the probability column as a 0–100% progress bar.
        "probability": st.column_config.ProgressColumn(
            label="Quality Score",
            width="medium",
            format="percent",
            min_value=0,
            max_value=1,
        )
    },
)
130
+
131
+
132
# Pick a random default row once per session so the first load already
# highlights (and plays) a segment; capped to the first 50 rows.
if "default_selection" not in st.session_state:
    st.session_state.default_selection = random.randint(
        0, min(49, len(segments_quality_scores_df) - 1)
    )

# Use the user's row selection if one exists, otherwise fall back to the
# session's random default.
if event and event.selection and event.selection["rows"]:
    row_idx = event.selection["rows"][0]
else:
    row_idx = st.session_state.default_selection

# Resolve the selected row back to its segment record.
# Fix: dropped the unused start_time/end_time assignments — the player below
# re-reads "start"/"end" from selected_segment itself.
df_row = segments_quality_scores_df.iloc[row_idx]
segment_id = int(df_row["segment_id"])
selected_segment = segments_quality_scores[segment_id]
150
+
151
# Playback panel: show the chosen segment, play its audio span, and render
# its aligned transcript text (right-to-left for Hebrew).
with col_main:
    st.write(f"Selected segment: {selected_segment}")
    # The audio player only supports whole seconds, so widen the window
    # outward to the nearest integers.
    play_from = math.floor(selected_segment["start"])
    play_until = math.ceil(selected_segment["end"])

    st.audio(
        sample_audio_file,
        start_time=play_from,
        end_time=play_until,
        autoplay=True,
    )
    transcript_segment = transcript_segments[segment_id]
    st.caption(f'<div dir="rtl">{transcript_segment["text"]}</div>', unsafe_allow_html=True)
    st.divider()
    st.caption(
        f"Note: The audio will start at {play_from} seconds and end at {play_until} seconds (rounded up/down) since this is the resolution of the player, actual segments are more accurate."
    )
168
+
169
+
170
with col_aux:
    # Scatter chart of per-segment quality over the plenum timeline.
    st.subheader("Segment Quality Over Time")

    # Prepare data for the chart (sorted by segment start time).
    chart_data = segments_quality_scores_df.copy()
    chart_data = chart_data.sort_values(by="start")

    # Fix: removed the mid-script `import altair as alt` (now at the top of
    # the file) and the redundant re-import of pandas (already imported above).

    # Base layer: every segment as a small point.
    base_chart = alt.Chart(chart_data).mark_circle(size=20).encode(
        x=alt.X('start:Q', title='Start Time (seconds)'),
        y=alt.Y('probability:Q', title='Quality Score', scale=alt.Scale(domain=[0, 1])),
        tooltip=['start', 'end', 'probability']
    )

    # Highlight layer: the currently selected segment as a larger red point.
    selected_point = pd.DataFrame([{
        'start': selected_segment['start'],
        'probability': selected_segment['probability']
    }])

    highlight = alt.Chart(selected_point).mark_circle(size=120, color='red').encode(
        x='start:Q',
        y='probability:Q'
    )

    # Combine the layers and display the chart.
    combined_chart = base_chart + highlight
    st.altair_chart(combined_chart, use_container_width=True)
205
+
206
# Collapsed panels exposing the full raw transcript and the raw metadata JSON.
raw_text_panel = st.expander("Raw Transcript Text", expanded=False)
with raw_text_panel:
    st.text_area(
        "Raw Transcript Text",
        value=sample_raw_text,
        label_visibility="collapsed",
        height=300,
        disabled=True,
    )

metadata_panel = st.expander("Sample Metadata", expanded=False)
with metadata_panel:
    st.json(sample_metadata)
pyproject.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "knesset-plenums-preview"
3
+ version = "0.1.0"
4
+ description = "Streamlit preview app for the Knesset plenums dataset"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11.9"
7
+ dependencies = [
8
+ "datasets>=3.5.0",
9
+ "huggingface-hub>=0.30.2",
10
+ "pandas>=2.2.3",
11
+ "python-dotenv>=1.1.0",
12
+ "streamlit==1.44.1",
13
+ ]
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ streamlit==1.44.1
2
+ datasets
3
+ huggingface_hub
4
+ pandas
uv.lock ADDED
The diff for this file is too large to render. See raw diff