akshayka's picture
Update notebook.py
1e1c71b verified
# /// script
# requires-python = ">=3.13"
# dependencies = [
# "marimo",
# "matplotlib==3.10.1",
# "numpy==2.2.3",
# ]
# ///
import marimo
# Written by marimo when the notebook is saved (presumably the marimo version).
__generated_with = "0.11.20"
# The notebook object; every cell below is a function decorated with `@app.cell`.
app = marimo.App()
@app.cell
def _():
    # Expose marimo to the other cells under the name `mo`.
    import marimo as mo
    return (mo,)
@app.cell(hide_code=True)
def _(mo):
    # Intro text. Paragraphs are separated by blank lines: in Markdown,
    # consecutive lines are joined into a single paragraph.
    mo.md(
        r"""
# Finding $\pi$ in colliding blocks

One of the remarkable things about mathematical constants like $\pi$ is how frequently they arise in nature, in the most surprising of places.

Inspired by 3Blue1Brown, this [marimo notebook](https://github.com/marimo-team/marimo) shows how the number of collisions incurred in a particular system involving two blocks converges to the digits of $\pi$.

**Tip:** Use the menu in the top right to reveal the notebook's code.
"""
    )
    return
@app.cell(hide_code=True)
def _(mo):
    # Exponent k for the heavy block's mass; other cells use 100**k.
    slider = mo.ui.slider(show_value=True, value=3, start=0, stop=3)
    return (slider,)
@app.cell(hide_code=True)
def _(mo, slider):
    # Section heading for the interactive part of the notebook.
    _heading = "## Simulate!"
    mo.md(_heading)
    return
@app.cell(hide_code=True)
def _(mo, slider):
    # Interpolate the slider element itself into the markdown prompt.
    _prompt = f"Use this slider to control the weight of the heavier block: {slider}"
    mo.md(_prompt)
    return
@app.cell(hide_code=True)
def _(mo, slider):
    # Show the heavy block's mass as 100^k, with k read from the slider.
    _exponent = slider.value
    mo.md(rf"The heavier block weighs **$100^{{ {_exponent} }}$** kg.")
    return
@app.cell(hide_code=True)
def _(mo):
    # Button that gates the (slow) simulation and rendering cells.
    _label = "Run simulation!"
    run_button = mo.ui.run_button(label=_label)
    # The cell's last expression, displayed as its output.
    run_button.right()
    return (run_button,)
@app.cell
def _(run_button, simulate_collisions, slider):
    # Bind every name this cell exports before branching, so the cell defines
    # the same names (None until the button is clicked) on every run. Without
    # this, the trailing `return` refers to unbound locals when the button has
    # not been pressed.
    ani = collisions = mass_ratio = None
    if run_button.value:
        # Heavy-block mass is 100**k for the slider's exponent k.
        mass_ratio = 100**slider.value
        _, ani, collisions = simulate_collisions(
            mass_ratio, total_time=15, dt=0.001
        )
    return ani, collisions, mass_ratio
@app.cell
def _(ani, mo, run_button):
    # After a run, encode the animation as an HTML5 video while showing a
    # spinner; the cell's last expression is `video`.
    video = None
    if run_button.value:
        with mo.status.spinner(title="Rendering collision video ..."):
            _html5 = ani.to_html5_video()
            video = mo.Html(_html5)
    video
    return (video,)
@app.cell(hide_code=True)
def _(mo):
    # Heading that introduces the embedded YouTube video below.
    _text = r"""
## The 3Blue1Brown video
If you haven't seen it, definitely check out the video that inspired this notebook:
"""
    mo.md(_text)
    return
@app.cell(hide_code=True)
def _(mo):
    # Collapsible section wrapping the embedded 3Blue1Brown video iframe.
    _player = mo.Html(
        '<iframe width="700" height="400" src="https://www.youtube.com/embed/6dTyOl1fmDo?si=xl9v6Y8x2e3r3A9I" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" referrerpolicy="strict-origin-when-cross-origin" allowfullscreen></iframe>'
    )
    mo.accordion({"🎥 Watch the video": _player})
    return
@app.cell
def _():
    # Third-party imports shared by the simulation and animation cells.
    # NOTE(review): `np` is exported, but no cell in this file references it.
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.animation as animation
    from matplotlib.patches import Rectangle
    return Rectangle, animation, np, plt
@app.cell
def _():
    class Block:
        """A block moving along a line at constant velocity between collisions.

        The collision code treats ``position`` as the block's left edge, so
        its right edge is ``position + size``.
        """

        def __init__(self, mass, velocity, position, size=1.0):
            self.mass, self.velocity = mass, velocity
            self.position, self.size = position, size

        def update(self, dt):
            """Drift forward by ``dt`` at the current velocity."""
            self.position += self.velocity * dt

        def collide(self, other):
            """Apply a 1-D perfectly elastic collision with ``other``.

            Both velocities are replaced in place, using the standard
            momentum- and energy-conserving formulas. Returns 1 (one
            collision) so callers can add the result to a running count.
            """
            m1, m2 = self.mass, other.mass
            v1, v2 = self.velocity, other.velocity
            total = m1 + m2
            # Both new velocities are computed from the old ones before either
            # is written back.
            self.velocity, other.velocity = (
                (m1 - m2) / total * v1 + (2 * m2) / total * v2,
                (2 * m1) / total * v1 + (m2 - m1) / total * v2,
            )
            return 1

    return (Block,)
@app.cell
def check_collisions():
    def check_collisions(small_block, big_block, wall_pos=0):
        """Resolve overlaps left by one fixed time step and count collisions.

        First, if the small block's right edge has passed the big block's
        left edge, the small block is pushed back flush and the two collide
        elastically. Then, if the small block is left of the wall, it is
        placed on the wall and its velocity is reversed.

        NOTE(review): at most one contact of each kind is handled per call,
        so when the small block would bounce several times within a single
        step, some collisions are missed.

        Returns the number of collisions handled by this call.
        """
        count = 0

        small_right_edge = small_block.position + small_block.size
        if small_right_edge > big_block.position:
            small_block.position = big_block.position - small_block.size
            count += small_block.collide(big_block)

        if small_block.position < wall_pos:
            small_block.position = wall_pos
            small_block.velocity = -small_block.velocity
            count += 1

        return count

    return (check_collisions,)
@app.cell
def _(Block, check_collisions, create_animation):
    # check_collisions is no longer called here. It resolves at most one
    # contact of each kind per fixed step, which undercounts whenever the
    # small block bounces several times within a single step.

    def simulate_collisions(
        mass_ratio, total_time=15, dt=0.001, animate=True, max_time=None
    ):
        """Simulate a small block bouncing between a wall and a heavy block.

        The small block (mass 1) starts at rest at x=2. The big block (mass
        ``mass_ratio``) starts at x=4 moving left at 0.5. The wall is at x=0
        and every collision is perfectly elastic.

        Collisions are resolved exactly. Within each step of length ``dt``
        the time of the next block/block or block/wall contact is computed,
        and every contact falling inside the step is processed in order. A
        "move, then clamp" scheme can only handle one contact of each kind
        per step, which loses collisions once the small block bounces many
        times per step.

        Args:
            mass_ratio: Mass of the big block (the small block has mass 1).
            total_time: Minimum simulated time.
            dt: Interval between recorded samples.
            animate: If True, build an animation with ``create_animation``.
            max_time: Hard cap on simulated time. Past ``total_time`` the
                simulation continues only while another collision is still
                pending, and never beyond this cap. Defaults to
                ``4 * total_time``.

        Returns:
            ``(axis, ani, collision_count)``. ``axis`` and ``ani`` are both
            None when ``animate`` is False.
        """
        wall_pos = 0
        inf = float("inf")

        def next_event(small, big, last):
            """Return ``(delay, kind)`` of the next contact.

            ``kind`` is "block" or "wall". ``(inf, None)`` means no further
            contact will occur. Contacts strictly alternate: a block contact
            reverses the relative velocity, and a wall contact sends the small
            block right. Repeating ``last`` is therefore skipped, which stops
            round-off from re-triggering the contact just handled.
            """
            delay, kind = inf, None
            closing_speed = small.velocity - big.velocity
            if last != "block" and closing_speed > 0:
                gap = big.position - (small.position + small.size)
                delay, kind = max(gap, 0.0) / closing_speed, "block"
            if last != "wall" and small.velocity < 0:
                wall_delay = max(small.position - wall_pos, 0.0) / -small.velocity
                if wall_delay < delay:
                    delay, kind = wall_delay, "wall"
            return delay, kind

        def advance(small, big, step, last):
            """Move both blocks by ``step``, handling every contact on the way.

            Returns ``(contacts, last)``: the number of collisions in this
            step and the kind of the most recent contact.
            """
            contacts = 0
            remaining = step
            while True:
                delay, kind = next_event(small, big, last)
                if delay > remaining:
                    break
                small.update(delay)
                big.update(delay)
                remaining -= delay
                if kind == "block":
                    # Snap flush to cancel round-off before bouncing.
                    small.position = big.position - small.size
                    contacts += small.collide(big)
                else:
                    small.position = wall_pos
                    small.velocity *= -1
                    contacts += 1
                last = kind
            small.update(remaining)
            big.update(remaining)
            return contacts, last

        small_block = Block(mass=1, velocity=0, position=2)
        big_block = Block(mass=mass_ratio, velocity=-0.5, position=4)
        time_limit = 4 * total_time if max_time is None else max_time

        time = 0
        collision_count = 0
        last_event = None

        # Recorded trajectory (one sample per step) for the animation.
        times = []
        small_positions = []
        big_positions = []
        collision_counts = []

        # Run for at least total_time. Afterwards keep going (up to
        # time_limit) only while another collision is still pending, so the
        # reported total is complete.
        while time < total_time or (
            time < time_limit
            and next_event(small_block, big_block, last_event)[1] is not None
        ):
            new_collisions, last_event = advance(
                small_block, big_block, dt, last_event
            )
            collision_count += new_collisions

            times.append(time)
            small_positions.append(small_block.position)
            big_positions.append(big_block.position)
            collision_counts.append(collision_count)
            time += dt

        print(f"Mass ratio: {mass_ratio}, Total collisions: {collision_count}")
        if animate:
            axis, ani = create_animation(
                times, small_positions, big_positions, collision_counts, mass_ratio
            )
        else:
            # Two targets need two values; a bare `None` cannot be unpacked.
            axis, ani = None, None
        return axis, ani, collision_count

    return (simulate_collisions,)
@app.cell
def _(Rectangle, animation, plt):
    def create_animation(
        times, small_positions, big_positions, collision_counts, mass_ratio
    ):
        """Animate a recorded block-collision run.

        The top panel draws the wall at x=0 and both blocks as 1x1 rectangles
        whose left edge is the recorded position, with weight labels. The
        bottom panel plots the cumulative collision count against time on a
        symlog y-axis.

        Args:
            times: Sample times, one per recorded step.
            small_positions: Small-block position at each sample.
            big_positions: Big-block position at each sample.
            collision_counts: Cumulative collision count at each sample.
            mass_ratio: Mass of the big block, shown in the title and label.

        All four sequences are indexed together and must be non-empty.

        Returns:
            ``(axis, ani)``: ``plt.gca()`` after the figure is built, and the
            ``FuncAnimation``. To save the video, call e.g.
            ``ani.save("pi_collisions.mp4", writer="ffmpeg", fps=30)``.
        """
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

        # --- Top panel: wall and blocks ---
        ax1.set_xlim(-1, 10)
        ax1.set_ylim(-1, 2)
        ax1.set_xlabel("Position")
        ax1.set_title(f"Block Collisions (Mass Ratio = {mass_ratio})")
        wall = plt.Line2D([0, 0], [-1, 2], color="black", linewidth=3)
        ax1.add_line(wall)
        small_block = Rectangle((small_positions[0], 0), 1, 1, color="blue")
        big_block = Rectangle((big_positions[0], 0), 1, 1, color="red")
        ax1.add_patch(small_block)
        ax1.add_patch(big_block)

        # Weight labels sit above the centre of each block (x + 0.5, y = 1.2).
        label_kwargs = dict(ha="center", va="center", fontweight="bold")
        small_label = ax1.text(
            small_positions[0] + 0.5, 1.2, "1kg", color="blue", **label_kwargs
        )
        big_label = ax1.text(
            big_positions[0] + 0.5,
            1.2,
            f"{mass_ratio}kg",
            color="red",
            **label_kwargs,
        )

        # --- Bottom panel: cumulative collision count ---
        ax2.set_xlim(0, times[-1])
        # max(..., 1): a run with zero collisions would otherwise request the
        # singular y-range (0, 0), which matplotlib warns about.
        ax2.set_ylim(0, max(collision_counts[-1], 1) * 1.1)
        ax2.set_xlabel("Time")
        ax2.set_ylabel("# Collisions:")
        ax2.set_yscale("symlog")
        (collision_line,) = ax2.plot([], [], "g-")
        collision_text = ax2.text(
            0.02, 0.9, "", transform=ax2.transAxes, fontsize="x-large"
        )

        # Every artist the animation mutates. With blit=True, both init() and
        # animate() return all of them so that each one gets redrawn.
        artists = (
            small_block,
            big_block,
            small_label,
            big_label,
            collision_line,
            collision_text,
        )

        def place_blocks(index):
            """Move both blocks and their labels to sample ``index``."""
            small_x = small_positions[index]
            big_x = big_positions[index]
            small_block.set_xy((small_x, 0))
            big_block.set_xy((big_x, 0))
            small_label.set_position((small_x + 0.5, 1.2))
            big_label.set_position((big_x + 0.5, 1.2))

        def init():
            place_blocks(0)
            collision_line.set_data([], [])
            collision_text.set_text("")
            return artists

        frame_step = 300  # simulation samples advanced per video frame
        last_index = len(times) - 1
        # ceil(last_index / frame_step) + 1 frames, so the final frame always
        # lands on the last sample (and shows the final collision count),
        # whatever the number of samples.
        n_frames = -(-last_index // frame_step) + 1

        def animate(i):
            index = min(i * frame_step, last_index)
            place_blocks(index)
            # Plot the count history up to and including this sample.
            collision_line.set_data(
                times[: index + 1], collision_counts[: index + 1]
            )
            collision_text.set_text(f"# Collisions: {collision_counts[index]}")
            return artists

        plt.tight_layout()
        ani = animation.FuncAnimation(
            fig,
            animate,
            frames=n_frames,
            init_func=init,
            blit=True,
            interval=30,
        )
        return plt.gca(), ani

    return (create_animation,)
if __name__ == "__main__":
    # Script entry point: hand control to marimo's `app.run()`.
    app.run()