import%20marimo%0A%0A__generated_with%20%3D%20%220.15.2%22%0Aapp%20%3D%20marimo.App(width%3D%22full%22%2C%20app_title%3D%22Spring-Mass%20ODE%20in%20Phydrax%22)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20time%20as%20time_mod%0A%0A%20%20%20%20import%20equinox%20as%20eqx%0A%20%20%20%20import%20jax%0A%20%20%20%20import%20jax.numpy%20as%20jnp%0A%20%20%20%20import%20jax.random%20as%20jr%0A%20%20%20%20import%20marimo%20as%20mo%0A%20%20%20%20import%20matplotlib.pyplot%20as%20plt%0A%20%20%20%20import%20optax%0A%0A%20%20%20%20import%20phydrax%20as%20phx%0A%0A%20%20%20%20plt.style.use(%22seaborn-v0_8-whitegrid%22)%0A%20%20%20%20return%20eqx%2C%20jax%2C%20jnp%2C%20jr%2C%20mo%2C%20optax%2C%20phx%2C%20plt%2C%20time_mod%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%23%20Coupled%20Spring-Mass%20ODE%20in%20Phydrax%0A%0A%20%20%20%20This%20notebook%20recreates%20the%20NVIDIA%20PhysicsNeMo%20spring-mass%20ODE%20benchmark%3A%0A%20%20%20%20%5BPhysicsNeMo%20Spring-Mass%20Example%5D(https%3A%2F%2Fdocs.nvidia.com%2Fphysicsnemo%2F25.11%2Fphysicsnemo-sym%2Fuser_guide%2Ffoundational%2Fode_spring_mass.html).%0A%0A%20%20%20%20Matrix%20form%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cddot%7Bx%7D(t)%20%2B%20Kx(t)%3D0%2C%5Cquad%0A%20%20%20%20x(t)%3D%5Cbegin%7Bbmatrix%7Dx_1(t)%5C%5Cx_2(t)%5C%5Cx_3(t)%5Cend%7Bbmatrix%7D%2C%5Cquad%0A%20%20%20%20K%3D%5Cbegin%7Bbmatrix%7D%0A%20%20%20%203%20%26%20-1%20%26%200%5C%5C%0A%20%20%20%20-1%20%26%202%20%26%20-1%5C%5C%0A%20%20%20%200%20%26%20-1%20%26%203%0A%20%20%20%20%5Cend%7Bbmatrix%7D.%0A%20%20%20%20%24%24%0A%0A%20%20%20%20We%20solve%20the%203-DOF%20system%20on%20%5C(t%5Cin%5B0%2C10%5D%5C)%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%5Cddot%7Bx%7D_1%20%2B%203x_1%20-%20x_2%20%26%3D%200%2C%20%5C%5C%0A%20%20%20%20%5Cddot%7Bx%7D_2%20%2B%202x_2%20-%20x_1%20-%20x_3%20%26%3D%200%2C%20%5C%5C%0A%20%20%20%20%5Cddot%7Bx%7D_3%20%2B%203x_3%20-%20x_2%20%26%3D%200.%0A%20%20%20%20%5Cend%7Baligned%7D%0A%20%20%20%20%24%24%0A%0A%20%20%20%20Internally%2C%20training%20is%20performed%20on%20normalized%20time%0A%20%20%20%20%5C(s%3D(t-t_%7B%5Cmin%7D)%2F(t_%7B%5Cmax%7D-t_%7B%5Cmin%7D)%5Cin%5B0%2C1%5D%5C)%2C%20with%20the%20residual%20scaled%20so%0A%20%20%20%20the%20learned%20solution%20matches%20the%20same%20physical%20system%20in%20%5C(t%5C).%0A%0A%20%20%20%20Initial%20conditions%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20x_1(0)%3D1%2C%5C%20x_2(0)%3D0%2C%5C%20x_3(0)%3D0%2C%5Cquad%0A%20%20%20%20%5Cdot%7Bx%7D_1(0)%3D%5Cdot%7Bx%7D_2(0)%3D%5Cdot%7Bx%7D_3(0)%3D0.%0A%20%20%20%20%24%24%0A%0A%20%20%20%20We%20enforce%20all%20six%20initial%20conditions%20exactly%20by%20construction%2C%20then%20train%20with%0A%20%20%20%20a%20single%20matrix-form%20residual%20loss.%0A%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.callout(%0A%20%20%20%20%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%23%20Why%20this%20setup%20is%20strong%0A%0A%20%20%20%20%20%20%20%20%20%20%20%201.%20ICs%20are%20enforced%20directly%20in%20the%20ansatz%20(no%20IC%20penalty%20balancing).%0A%20%20%20%20%20%20%20%20%20%20%20%202.%20Training%20objective%20focuses%20on%20residual%20physics%20only.%0A%20%20%20%20%20%20%20%20%20%20%20%203.%20Time%20derivatives%20use%20AD%20with%20JVP%20(%60dt_n(...%2C%20order%3D2%2C%20ad_engine%3D%22jvp%22)%60).%0A%20%20%20%20%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20%20%20%20%20)%2C%0A%20%20%20%20%20%20%20%20kind%3D%22success%22%2C%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20%23%20-------------------------------------------------------------------------%0A%20%20%20%20%23%20Configurations%0A%20%20%20%20%23%20-------------------------------------------------------------------------%0A%20%20%20%20width_size%20%3D%2020%0A%20%20%20%20depth%20%3D%206%0A%20%20%20%20num_iter%20%3D%20500%0A%20%20%20%20learning_rate%20%3D%201e-3%0A%20%20%20%20num_t_interior%20%3D%2010_000%0A%20%20%20%20nt_plot%20%3D%202000%0A%20%20%20%20seed%20%3D%200%0A%20%20%20%20t_min%20%3D%200.0%0A%20%20%20%20t_max%20%3D%2010.0%0A%0A%20%20%20%20%23%20Published%20PhysicsNeMo%20spring-mass%20baseline%20specs%0A%20%20%20%20physicsnemo_steps%20%3D%2050_000%0A%20%20%20%20physicsnemo_interior_batch%20%3D%20500%0A%20%20%20%20physicsnemo_params%20%3D%201_315_843%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20depth%2C%0A%20%20%20%20%20%20%20%20learning_rate%2C%0A%20%20%20%20%20%20%20%20nt_plot%2C%0A%20%20%20%20%20%20%20%20num_iter%2C%0A%20%20%20%20%20%20%20%20num_t_interior%2C%0A%20%20%20%20%20%20%20%20physicsnemo_interior_batch%2C%0A%20%20%20%20%20%20%20%20physicsnemo_params%2C%0A%20%20%20%20%20%20%20%20physicsnemo_steps%2C%0A%20%20%20%20%20%20%20%20seed%2C%0A%20%20%20%20%20%20%20%20t_max%2C%0A%20%20%20%20%20%20%20%20t_min%2C%0A%20%20%20%20%20%20%20%20width_size%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell%0Adef%20_(eqx%2C%20jax%2C%20jnp%2C%20phx)%3A%0A%20%20%20%20%23%20-------------------------------------------------------------------------%0A%20%20%20%20%23%20Utility%20functions%0A%20%20%20%20%23%20-------------------------------------------------------------------------%0A%20%20%20%20def%20exact_states(t%3A%20jax.Array)%20-%3E%20jax.Array%3A%0A%20%20%20%20%20%20%20%20c1%20%3D%20jnp.cos(t)%0A%20%20%20%20%20%20%20%20c3%20%3D%20jnp.cos(jnp.sqrt(3.0)%20*%20t)%0A%20%20%20%20%20%20%20%20c2%20%3D%20jnp.cos(2.0%20*%20t)%0A%20%20%20%20%20%20%20%20x1%20%3D%20(1.0%20%2F%206.0)%20*%20c1%20%2B%200.5%20*%20c3%20%2B%20(1.0%20%2F%203.0)%20*%20c2%0A%20%20%20%20%20%20%20%20x2%20%3D%20(1.0%20%2F%203.0)%20*%20c1%20-%20(1.0%20%2F%203.0)%20*%20c2%0A%20%20%20%20%20%20%20%20x3%20%3D%20(1.0%20%2F%206.0)%20*%20c1%20-%200.5%20*%20c3%20%2B%20(1.0%20%2F%203.0)%20*%20c2%0A%20%20%20%20%20%20%20%20return%20jnp.stack((x1%2C%20x2%2C%20x3)%2C%20axis%3D-1)%0A%0A%20%20%20%20def%20exact_velocities(t%3A%20jax.Array)%20-%3E%20jax.Array%3A%0A%20%20%20%20%20%20%20%20s1%20%3D%20jnp.sin(t)%0A%20%20%20%20%20%20%20%20s3%20%3D%20jnp.sin(jnp.sqrt(3.0)%20*%20t)%0A%20%20%20%20%20%20%20%20s2%20%3D%20jnp.sin(2.0%20*%20t)%0A%20%20%20%20%20%20%20%20v1%20%3D%20-(1.0%20%2F%206.0)%20*%20s1%20-%200.5%20*%20jnp.sqrt(3.0)%20*%20s3%20-%20(2.0%20%2F%203.0)%20*%20s2%0A%20%20%20%20%20%20%20%20v2%20%3D%20-(1.0%20%2F%203.0)%20*%20s1%20%2B%20(2.0%20%2F%203.0)%20*%20s2%0A%20%20%20%20%20%20%20%20v3%20%3D%20-(1.0%20%2F%206.0)%20*%20s1%20%2B%200.5%20*%20jnp.sqrt(3.0)%20*%20s3%20-%20(2.0%20%2F%203.0)%20*%20s2%0A%20%20%20%20%20%20%20%20return%20jnp.stack((v1%2C%20v2%2C%20v3)%2C%20axis%3D-1)%0A%0A%20%20%20%20def%20total_energy(x%3A%20jax.Array%2C%20v%3A%20jax.Array)%20-%3E%20jax.Array%3A%0A%20%20%20%20%20%20%20%20x1%20%3D%20x%5B%3A%2C%200%5D%0A%20%20%20%20%20%20%20%20x2%20%3D%20x%5B%3A%2C%201%5D%0A%20%20%20%20%20%20%20%20x3%20%3D%20x%5B%3A%2C%202%5D%0A%20%20%20%20%20%20%20%20kinetic%20%3D%200.5%20*%20jnp.sum(v**2%2C%20axis%3D1)%0A%20%20%20%20%20%20%20%20potential%20%3D%20(x1**2)%20%2B%200.5%20*%20(x2%20-%20x1)%20**%202%20%2B%200.5%20*%20(x3%20-%20x2)%20**%202%20%2B%20(x3**2)%0A%20%20%20%20%20%20%20%20return%20kinetic%20%2B%20potential%0A%0A%20%20%20%20def%20count_parameters(module)%20-%3E%20int%3A%0A%20%20%20%20%20%20%20%20leaves%20%3D%20jax.tree_util.tree_leaves(eqx.filter(module%2C%20eqx.is_inexact_array))%0A%20%20%20%20%20%20%20%20return%20sum(leaf.size%20for%20leaf%20in%20leaves)%0A%0A%20%20%20%20def%20make_solver(%0A%20%20%20%20%20%20%20%20*%2C%0A%20%20%20%20%20%20%20%20t_min%3A%20float%2C%0A%20%20%20%20%20%20%20%20t_max%3A%20float%2C%0A%20%20%20%20%20%20%20%20width_size%3A%20int%2C%0A%20%20%20%20%20%20%20%20depth%3A%20int%2C%0A%20%20%20%20%20%20%20%20num_t_interior%3A%20int%2C%0A%20%20%20%20%20%20%20%20key%3A%20jax.Array%2C%0A%20%20%20%20)%20-%3E%20tuple%5Bphx.solver.FunctionalSolver%2C%20int%5D%3A%0A%20%20%20%20%20%20%20%20t_min_phys%20%3D%20t_min%0A%20%20%20%20%20%20%20%20t_max_phys%20%3D%20t_max%0A%20%20%20%20%20%20%20%20t_span%20%3D%20t_max_phys%20-%20t_min_phys%0A%20%20%20%20%20%20%20%20if%20t_span%20%3C%3D%200.0%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20raise%20ValueError(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20f%22Expected%20positive%20physical%20time%20span%2C%20got%20t_min%3D%7Bt_min_phys%7D%2C%20t_max%3D%7Bt_max_phys%7D.%22%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%23%20Train%20on%20normalized%20time%20s%20in%20%5B0%2C%201%5D%2C%20while%20preserving%20physical%20ODE%20scaling.%0A%20%20%20%20%20%20%20%20time_domain%20%3D%20phx.domain.TimeInterval(0.0%2C%201.0)%0A%20%20%20%20%20%20%20%20structure_t%20%3D%20phx.domain.ProductStructure(((%22t%22%2C)%2C))%0A%20%20%20%20%20%20%20%20x_model%20%3D%20phx.nn.SeparableMLP(%0A%20%20%20%20%20%20%20%20%20%20%20%20in_size%3D%22scalar%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20out_size%3D3%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20split_input%3D3%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20width_size%3Dwidth_size%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20depth%3Ddepth%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20activation%3Dphx.nn.Stan(width_size)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20key%3Dkey%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20x_raw%20%3D%20time_domain.Model(%22t%22)(x_model)%0A%0A%20%20%20%20%20%20%20%20%40time_domain.Function(%22t%22)%0A%20%20%20%20%20%20%20%20def%20tau(t)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20isinstance(t%2C%20tuple)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20if%20len(t)%20!%3D%201%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20raise%20ValueError(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20f%22Expected%20a%20single%20scalar-axis%20tuple%20input%2C%20got%20%7Blen(t)%7D%20axes.%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20t%20%3D%20t%5B0%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20t%0A%0A%20%20%20%20%20%20%20%20%40time_domain.Function()%0A%20%20%20%20%20%20%20%20def%20x_anchor()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20jnp.asarray(%5B1.0%2C%200.0%2C%200.0%5D%2C%20dtype%3Dfloat)%0A%0A%20%20%20%20%20%20%20%20%23%20Hard%20IC%20ansatz%20is%20intentionally%20used%20here%3A%20it%20yields%20better%20residual%0A%20%20%20%20%20%20%20%20%23%20conditioning%20for%20this%20ODE%20than%20the%20generic%20initial-overlay%20parameterization.%0A%20%20%20%20%20%20%20%20tau2%20%3D%20tau%20*%20tau%0A%20%20%20%20%20%20%20%20x%20%3D%20x_anchor%20%2B%20tau2%20*%20x_raw%0A%0A%20%20%20%20%20%20%20%20k_mat%20%3D%20jnp.asarray(%0A%20%20%20%20%20%20%20%20%20%20%20%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5B3.0%2C%20-1.0%2C%200.0%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5B-1.0%2C%202.0%2C%20-1.0%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5B0.0%2C%20-1.0%2C%203.0%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20dtype%3Dfloat%2C%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20residual%20%3D%20phx.constraints.ContinuousPointwiseInteriorConstraint(%0A%20%20%20%20%20%20%20%20%20%20%20%20%22x%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20time_domain%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20operator%3Dlambda%20x%3A%20(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20(1.0%20%2F%20(t_span%20*%20t_span))%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20*%20phx.operators.dt_n(x%2C%20var%3D%22t%22%2C%20order%3D2%2C%20ad_engine%3D%22jvp%22)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%2B%20phx.operators.einsum(%22ij%2C...j-%3E...i%22%2C%20k_mat%2C%20x)%0A%20%20%20%20%20%20%20%20%20%20%20%20)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20num_points%3Dnum_t_interior%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20structure%3Dstructure_t%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20reduction%3D%22mean%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20label%3D%22ode_matrix%22%2C%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20solver%20%3D%20phx.solver.FunctionalSolver(%0A%20%20%20%20%20%20%20%20%20%20%20%20functions%3D%7B%22x%22%3A%20x%7D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20constraints%3D%5Bresidual%5D%2C%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20model_params%20%3D%20count_parameters(x_model)%0A%20%20%20%20%20%20%20%20return%20solver%2C%20model_params%0A%0A%20%20%20%20def%20evaluate_solver(%0A%20%20%20%20%20%20%20%20solver%3A%20phx.solver.FunctionalSolver%2C%0A%20%20%20%20%20%20%20%20*%2C%0A%20%20%20%20%20%20%20%20t_min%3A%20float%2C%0A%20%20%20%20%20%20%20%20t_max%3A%20float%2C%0A%20%20%20%20%20%20%20%20nt%3A%20int%2C%0A%20%20%20%20)%20-%3E%20tuple%5Bjax.Array%2C%20jax.Array%2C%20jax.Array%2C%20jax.Array%2C%20jax.Array%2C%20jax.Array%5D%3A%0A%20%20%20%20%20%20%20%20fields%20%3D%20solver.ansatz_functions()%0A%20%20%20%20%20%20%20%20t_min_phys%20%3D%20t_min%0A%20%20%20%20%20%20%20%20t_max_phys%20%3D%20t_max%0A%20%20%20%20%20%20%20%20t_span%20%3D%20t_max_phys%20-%20t_min_phys%0A%20%20%20%20%20%20%20%20if%20t_span%20%3C%3D%200.0%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20raise%20ValueError(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20f%22Expected%20positive%20physical%20time%20span%2C%20got%20t_min%3D%7Bt_min_phys%7D%2C%20t_max%3D%7Bt_max_phys%7D.%22%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20t%20%3D%20jnp.linspace(t_min_phys%2C%20t_max_phys%2C%20nt)%0A%20%20%20%20%20%20%20%20s%20%3D%20(t%20-%20t_min_phys)%20%2F%20t_span%0A%0A%20%20%20%20%20%20%20%20x_fun%20%3D%20fields%5B%22x%22%5D%0A%20%20%20%20%20%20%20%20x_pred%20%3D%20jax.vmap(lambda%20s_i%3A%20x_fun.func(s_i))(s)%0A%0A%20%20%20%20%20%20%20%20v_s_fun%20%3D%20phx.operators.dt(x_fun%2C%20var%3D%22t%22%2C%20ad_engine%3D%22jvp%22)%0A%20%20%20%20%20%20%20%20v_s_pred%20%3D%20jax.vmap(lambda%20s_i%3A%20v_s_fun.func(s_i))(s)%0A%20%20%20%20%20%20%20%20v_pred%20%3D%20v_s_pred%20%2F%20t_span%0A%0A%20%20%20%20%20%20%20%20x_true%20%3D%20exact_states(t)%0A%20%20%20%20%20%20%20%20v_true%20%3D%20exact_velocities(t)%0A%20%20%20%20%20%20%20%20err%20%3D%20x_pred%20-%20x_true%0A%20%20%20%20%20%20%20%20return%20t%2C%20x_pred%2C%20x_true%2C%20err%2C%20v_pred%2C%20v_true%0A%0A%20%20%20%20def%20initial_condition_errors(%0A%20%20%20%20%20%20%20%20solver%3A%20phx.solver.FunctionalSolver%2C%0A%20%20%20%20%20%20%20%20*%2C%0A%20%20%20%20%20%20%20%20t0_phys%3A%20float%2C%0A%20%20%20%20%20%20%20%20t_min%3A%20float%2C%0A%20%20%20%20%20%20%20%20t_max%3A%20float%2C%0A%20%20%20%20)%20-%3E%20dict%5Bstr%2C%20float%5D%3A%0A%20%20%20%20%20%20%20%20t_min_phys%20%3D%20t_min%0A%20%20%20%20%20%20%20%20t_max_phys%20%3D%20t_max%0A%20%20%20%20%20%20%20%20t_span%20%3D%20t_max_phys%20-%20t_min_phys%0A%20%20%20%20%20%20%20%20if%20t_span%20%3C%3D%200.0%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20raise%20ValueError(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20f%22Expected%20positive%20physical%20time%20span%2C%20got%20t_min%3D%7Bt_min_phys%7D%2C%20t_max%3D%7Bt_max_phys%7D.%22%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20s0%20%3D%20(t0_phys%20-%20t_min_phys)%20%2F%20t_span%0A%20%20%20%20%20%20%20%20fields%20%3D%20solver.ansatz_functions()%0A%20%20%20%20%20%20%20%20x_fun%20%3D%20fields%5B%22x%22%5D%0A%20%20%20%20%20%20%20%20v_s_fun%20%3D%20phx.operators.dt(x_fun%2C%20var%3D%22t%22%2C%20ad_engine%3D%22jvp%22)%0A%0A%20%20%20%20%20%20%20%20x0%20%3D%20jnp.asarray(x_fun.func(s0))%0A%20%20%20%20%20%20%20%20v0%20%3D%20jnp.asarray(v_s_fun.func(s0))%20%2F%20t_span%0A%0A%20%20%20%20%20%20%20%20return%20%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%22x1_0_error%22%3A%20float(jnp.abs(x0%5B0%5D%20-%201.0))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22x2_0_error%22%3A%20float(jnp.abs(x0%5B1%5D%20-%200.0))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22x3_0_error%22%3A%20float(jnp.abs(x0%5B2%5D%20-%200.0))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22v1_0_error%22%3A%20float(jnp.abs(v0%5B0%5D%20-%200.0))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22v2_0_error%22%3A%20float(jnp.abs(v0%5B1%5D%20-%200.0))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22v3_0_error%22%3A%20float(jnp.abs(v0%5B2%5D%20-%200.0))%2C%0A%20%20%20%20%20%20%20%20%7D%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20evaluate_solver%2C%0A%20%20%20%20%20%20%20%20exact_states%2C%0A%20%20%20%20%20%20%20%20initial_condition_errors%2C%0A%20%20%20%20%20%20%20%20make_solver%2C%0A%20%20%20%20%20%20%20%20total_energy%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20depth%2C%0A%20%20%20%20jr%2C%0A%20%20%20%20learning_rate%2C%0A%20%20%20%20make_solver%2C%0A%20%20%20%20mo%2C%0A%20%20%20%20num_iter%2C%0A%20%20%20%20num_t_interior%2C%0A%20%20%20%20optax%2C%0A%20%20%20%20seed%2C%0A%20%20%20%20t_max%2C%0A%20%20%20%20t_min%2C%0A%20%20%20%20time_mod%2C%0A%20%20%20%20width_size%2C%0A)%3A%0A%20%20%20%20%23%20-------------------------------------------------------------------------%0A%20%20%20%20%23%20Main%20execution%20path%0A%20%20%20%20%23%20-------------------------------------------------------------------------%0A%20%20%20%20solver%2C%20our_params%20%3D%20make_solver(%0A%20%20%20%20%20%20%20%20t_min%3Dt_min%2C%0A%20%20%20%20%20%20%20%20t_max%3Dt_max%2C%0A%20%20%20%20%20%20%20%20width_size%3Dwidth_size%2C%0A%20%20%20%20%20%20%20%20depth%3Ddepth%2C%0A%20%20%20%20%20%20%20%20num_t_interior%3Dnum_t_interior%2C%0A%20%20%20%20%20%20%20%20key%3Djr.key(seed)%2C%0A%20%20%20%20)%0A%0A%20%20%20%20t0%20%3D%20time_mod.perf_counter()%0A%20%20%20%20init_loss%20%3D%20float(solver.loss(key%3Djr.key(seed%20%2B%201)))%0A%20%20%20%20trained_solver%20%3D%20solver.solve(%0A%20%20%20%20%20%20%20%20num_iter%3Dnum_iter%2C%0A%20%20%20%20%20%20%20%20optim%3Doptax.rprop(learning_rate)%2C%0A%20%20%20%20%20%20%20%20seed%3Dseed%2C%0A%20%20%20%20%20%20%20%20jit%3DTrue%2C%0A%20%20%20%20%20%20%20%20keep_best%3DTrue%2C%0A%20%20%20%20)%0A%20%20%20%20final_loss%20%3D%20float(trained_solver.loss(key%3Djr.key(seed%20%2B%202)))%0A%20%20%20%20elapsed%20%3D%20time_mod.perf_counter()%20-%20t0%0A%0A%20%20%20%20train_stats%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%22init_loss%22%3A%20init_loss%2C%0A%20%20%20%20%20%20%20%20%22final_loss%22%3A%20final_loss%2C%0A%20%20%20%20%20%20%20%20%22elapsed_s%22%3A%20elapsed%2C%0A%20%20%20%20%20%20%20%20%22s_per_iter%22%3A%20elapsed%20%2F%20max(num_iter%2C%201)%2C%0A%20%20%20%20%7D%0A%0A%20%20%20%20train_status%20%3D%20mo.callout(%0A%20%20%20%20%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20%20%20%20%20f%22%22%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%E2%9C%85%20Training%20complete%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20-%20initial%20loss%3A%20%60%7Binit_loss%3A.6e%7D%60%0A%20%20%20%20%20%20%20%20%20%20%20%20-%20final%20loss%3A%20%60%7Bfinal_loss%3A.6e%7D%60%0A%20%20%20%20%20%20%20%20%20%20%20%20-%20elapsed%3A%20%60%7Belapsed%3A.2f%7Ds%60%20(%60%7Btrain_stats%5B%22s_per_iter%22%5D%3A.4f%7Ds%2Fiter%60)%0A%20%20%20%20%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20%20%20%20%20)%2C%0A%20%20%20%20%20%20%20%20kind%3D%22success%22%2C%0A%20%20%20%20)%0A%20%20%20%20train_status%0A%20%20%20%20return%20our_params%2C%20train_stats%2C%20trained_solver%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20evaluate_solver%2C%0A%20%20%20%20exact_states%2C%0A%20%20%20%20initial_condition_errors%2C%0A%20%20%20%20jnp%2C%0A%20%20%20%20nt_plot%2C%0A%20%20%20%20t_max%2C%0A%20%20%20%20t_min%2C%0A%20%20%20%20total_energy%2C%0A%20%20%20%20trained_solver%2C%0A)%3A%0A%20%20%20%20t_diag%2C%20x_pred_diag%2C%20x_true_diag%2C%20x_err_diag%2C%20v_pred_diag%2C%20v_true_diag%20%3D%20(%0A%20%20%20%20%20%20%20%20evaluate_solver(%0A%20%20%20%20%20%20%20%20%20%20%20%20trained_solver%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20t_min%3Dt_min%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20t_max%3Dt_max%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20nt%3Dnt_plot%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20)%0A%0A%20%20%20%20ic_err%20%3D%20initial_condition_errors(%0A%20%20%20%20%20%20%20%20trained_solver%2C%0A%20%20%20%20%20%20%20%20t0_phys%3Dt_min%2C%0A%20%20%20%20%20%20%20%20t_min%3Dt_min%2C%0A%20%20%20%20%20%20%20%20t_max%3Dt_max%2C%0A%20%20%20%20)%0A%20%20%20%20v_err_diag%20%3D%20v_pred_diag%20-%20v_true_diag%0A%0A%20%20%20%20e_pred_diag%20%3D%20total_energy(x_pred_diag%2C%20v_pred_diag)%0A%20%20%20%20e_true_diag%20%3D%20total_energy(exact_states(t_diag)%2C%20v_true_diag)%0A%20%20%20%20e0_diag%20%3D%20jnp.maximum(jnp.abs(e_pred_diag%5B0%5D)%2C%201e-12)%0A%0A%20%20%20%20diag_stats%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%22state_l2%22%3A%20float(jnp.sqrt(jnp.mean(x_err_diag**2)))%2C%0A%20%20%20%20%20%20%20%20%22state_linf%22%3A%20float(jnp.max(jnp.abs(x_err_diag)))%2C%0A%20%20%20%20%20%20%20%20%22velocity_l2%22%3A%20float(jnp.sqrt(jnp.mean(v_err_diag**2)))%2C%0A%20%20%20%20%20%20%20%20%22velocity_linf%22%3A%20float(jnp.max(jnp.abs(v_err_diag)))%2C%0A%20%20%20%20%20%20%20%20%22energy_abs_drift%22%3A%20float(jnp.max(jnp.abs(e_pred_diag%20-%20e_pred_diag%5B0%5D)))%2C%0A%20%20%20%20%20%20%20%20%22energy_rel_drift%22%3A%20float(%0A%20%20%20%20%20%20%20%20%20%20%20%20jnp.max(jnp.abs(e_pred_diag%20-%20e_pred_diag%5B0%5D))%20%2F%20e0_diag%0A%20%20%20%20%20%20%20%20)%2C%0A%20%20%20%20%20%20%20%20%22energy_true_abs_drift%22%3A%20float(jnp.max(jnp.abs(e_true_diag%20-%20e_true_diag%5B0%5D)))%2C%0A%20%20%20%20%20%20%20%20**ic_err%2C%0A%20%20%20%20%7D%0A%0A%20%20%20%20plot_data%20%3D%20(t_diag%2C%20x_pred_diag%2C%20x_true_diag%2C%20x_err_diag%2C%20e_pred_diag)%0A%20%20%20%20return%20diag_stats%2C%20plot_data%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(diag_stats%2C%20mo%2C%20train_stats)%3A%0A%20%20%20%20diagnostics_panel%20%3D%20mo.callout(%22Diagnostics%20unavailable.%22%2C%20kind%3D%22warn%22)%0A%20%20%20%20if%20diag_stats%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20msg%20%3D%20f%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%20Diagnostics%0A%0A%20%20%20%20%20%20%20%20-%20State%20L2%20error%3A%20%60%7Bdiag_stats%5B%22state_l2%22%5D%3A.3e%7D%60%0A%20%20%20%20%20%20%20%20-%20State%20Linf%20error%3A%20%60%7Bdiag_stats%5B%22state_linf%22%5D%3A.3e%7D%60%0A%20%20%20%20%20%20%20%20-%20Velocity%20L2%20error%3A%20%60%7Bdiag_stats%5B%22velocity_l2%22%5D%3A.3e%7D%60%0A%20%20%20%20%20%20%20%20-%20Velocity%20Linf%20error%3A%20%60%7Bdiag_stats%5B%22velocity_linf%22%5D%3A.3e%7D%60%0A%20%20%20%20%20%20%20%20-%20Energy%20abs%20drift%3A%20%60%7Bdiag_stats%5B%22energy_abs_drift%22%5D%3A.3e%7D%60%0A%20%20%20%20%20%20%20%20-%20Energy%20rel%20drift%3A%20%60%7Bdiag_stats%5B%22energy_rel_drift%22%5D%3A.3e%7D%60%0A%0A%20%20%20%20%20%20%20%20Initial-condition%20residuals%3A%0A%20%20%20%20%20%20%20%20-%20%60%7Cx1(0)-1%7C%60%3A%20%60%7Bdiag_stats%5B%22x1_0_error%22%5D%3A.3e%7D%60%0A%20%20%20%20%20%20%20%20-%20%60%7Cx2(0)-0%7C%60%3A%20%60%7Bdiag_stats%5B%22x2_0_error%22%5D%3A.3e%7D%60%0A%20%20%20%20%20%20%20%20-%20%60%7Cx3(0)-0%7C%60%3A%20%60%7Bdiag_stats%5B%22x3_0_error%22%5D%3A.3e%7D%60%0A%20%20%20%20%20%20%20%20-%20%60%7Cx1'(0)-0%7C%60%3A%20%60%7Bdiag_stats%5B%22v1_0_error%22%5D%3A.3e%7D%60%0A%20%20%20%20%20%20%20%20-%20%60%7Cx2'(0)-0%7C%60%3A%20%60%7Bdiag_stats%5B%22v2_0_error%22%5D%3A.3e%7D%60%0A%20%20%20%20%20%20%20%20-%20%60%7Cx3'(0)-0%7C%60%3A%20%60%7Bdiag_stats%5B%22v3_0_error%22%5D%3A.3e%7D%60%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20%20%20%20%20if%20train_stats%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20msg%20%2B%3D%20(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20f%22%5Cn-%20loss%20(init%20%E2%86%92%20final)%3A%20%60%7Btrain_stats%5B'init_loss'%5D%3A.3e%7D%60%20%E2%86%92%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20f%22%60%7Btrain_stats%5B'final_loss'%5D%3A.3e%7D%60%22%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20diagnostics_panel%20%3D%20mo.md(msg)%0A%20%20%20%20diagnostics_panel%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(jnp%2C%20mo%2C%20plot_data%2C%20plt)%3A%0A%20%20%20%20plot_panel%20%3D%20mo.callout(%22No%20plot%20data%20available.%22%2C%20kind%3D%22warn%22)%0A%20%20%20%20if%20plot_data%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20t_plot%2C%20x_pred_plot%2C%20x_true_plot%2C%20x_err_plot%2C%20e_pred_plot%20%3D%20plot_data%0A%20%20%20%20%20%20%20%20max_abs_state_err%20%3D%20jnp.max(jnp.abs(x_err_plot)%2C%20axis%3D1)%0A%20%20%20%20%20%20%20%20abs_energy_drift%20%3D%20jnp.abs(e_pred_plot%20-%20e_pred_plot%5B0%5D)%0A%0A%20%20%20%20%20%20%20%20fig%2C%20axes%20%3D%20plt.subplots(2%2C%202%2C%20figsize%3D(11.6%2C%207.2)%2C%20constrained_layout%3DTrue)%0A%20%20%20%20%20%20%20%20ax0%20%3D%20axes%5B0%2C%200%5D%0A%20%20%20%20%20%20%20%20ax1%20%3D%20axes%5B0%2C%201%5D%0A%20%20%20%20%20%20%20%20ax2%20%3D%20axes%5B1%2C%200%5D%0A%20%20%20%20%20%20%20%20ax3%20%3D%20axes%5B1%2C%201%5D%0A%0A%20%20%20%20%20%20%20%20ax0.plot(t_plot%2C%20x_pred_plot%5B%3A%2C%200%5D%2C%20color%3D%22tab%3Ablue%22%2C%20lw%3D2.0%2C%20label%3D%22Phydrax%22)%0A%20%20%20%20%20%20%20%20ax0.plot(t_plot%2C%20x_true_plot%5B%3A%2C%200%5D%2C%20color%3D%22black%22%2C%20lw%3D1.2%2C%20ls%3D%22--%22%2C%20label%3D%22Exact%22)%0A%20%20%20%20%20%20%20%20ax0.set_title(%22x1(t)%22)%0A%20%20%20%20%20%20%20%20ax0.set_xlabel(%22t%22)%0A%20%20%20%20%20%20%20%20ax0.set_ylabel(%22displacement%22)%0A%20%20%20%20%20%20%20%20ax0.legend(loc%3D%22upper%20right%22)%0A%0A%20%20%20%20%20%20%20%20ax1.plot(t_plot%2C%20x_pred_plot%5B%3A%2C%201%5D%2C%20color%3D%22tab%3Aorange%22%2C%20lw%3D2.0%2C%20label%3D%22Phydrax%22)%0A%20%20%20%20%20%20%20%20ax1.plot(t_plot%2C%20x_true_plot%5B%3A%2C%201%5D%2C%20color%3D%22black%22%2C%20lw%3D1.2%2C%20ls%3D%22--%22%2C%20label%3D%22Exact%22)%0A%20%20%20%20%20%20%20%20ax1.set_title(%22x2(t)%22)%0A%20%20%20%20%20%20%20%20ax1.set_xlabel(%22t%22)%0A%20%20%20%20%20%20%20%20ax1.set_ylabel(%22displacement%22)%0A%20%20%20%20%20%20%20%20ax1.legend(loc%3D%22upper%20right%22)%0A%0A%20%20%20%20%20%20%20%20ax2.plot(t_plot%2C%20x_pred_plot%5B%3A%2C%202%5D%2C%20color%3D%22tab%3Agreen%22%2C%20lw%3D2.0%2C%20label%3D%22Phydrax%22)%0A%20%20%20%20%20%20%20%20ax2.plot(t_plot%2C%20x_true_plot%5B%3A%2C%202%5D%2C%20color%3D%22black%22%2C%20lw%3D1.2%2C%20ls%3D%22--%22%2C%20label%3D%22Exact%22)%0A%20%20%20%20%20%20%20%20ax2.set_title(%22x3(t)%22)%0A%20%20%20%20%20%20%20%20ax2.set_xlabel(%22t%22)%0A%20%20%20%20%20%20%20%20ax2.set_ylabel(%22displacement%22)%0A%20%20%20%20%20%20%20%20ax2.legend(loc%3D%22upper%20right%22)%0A%0A%20%20%20%20%20%20%20%20ax3.semilogy(%0A%20%20%20%20%20%20%20%20%20%20%20%20t_plot%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20jnp.asarray(max_abs_state_err)%20%2B%201e-15%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20color%3D%22tab%3Ared%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20lw%3D2.0%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20label%3D%22max%20%7Cstate%20error%7C%22%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20ax3.semilogy(%0A%20%20%20%20%20%20%20%20%20%20%20%20t_plot%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20jnp.asarray(abs_energy_drift)%20%2B%201e-15%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20color%3D%22tab%3Apurple%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20lw%3D1.7%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20label%3D%22%7CE(t)-E(0)%7C%22%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20ax3.set_title(%22Error%20and%20Invariant%20Drift%22)%0A%20%20%20%20%20%20%20%20ax3.set_xlabel(%22t%22)%0A%20%20%20%20%20%20%20%20ax3.set_ylabel(%22log%20scale%22)%0A%20%20%20%20%20%20%20%20ax3.legend(loc%3D%22upper%20right%22)%0A%0A%20%20%20%20%20%20%20%20plot_panel%20%3D%20mo.hstack(%5Bmo.md(%22%22)%2C%20fig%2C%20mo.md(%22%22)%5D%2C%20widths%3D%5B1%2C%208%2C%201%5D)%0A%20%20%20%20plot_panel%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(%0A%20%20%20%20diag_stats%2C%0A%20%20%20%20jax%2C%0A%20%20%20%20mo%2C%0A%20%20%20%20num_iter%2C%0A%20%20%20%20num_t_interior%2C%0A%20%20%20%20our_params%2C%0A%20%20%20%20physicsnemo_interior_batch%2C%0A%20%20%20%20physicsnemo_params%2C%0A%20%20%20%20physicsnemo_steps%2C%0A%20%20%20%20train_stats%2C%0A)%3A%0A%20%20%20%20our_steps%20%3D%20num_iter%0A%20%20%20%20our_points_per_step%20%3D%20num_t_interior%0A%20%20%20%20physicsnemo_interior%20%3D%20physicsnemo_interior_batch%0A%20%20%20%20batch_ratio%20%3D%20our_points_per_step%20%2F%20max(physicsnemo_interior%2C%201)%0A%0A%20%20%20%20step_reduction%20%3D%20100.0%20*%20((physicsnemo_steps%20-%20our_steps)%20%2F%20physicsnemo_steps)%0A%20%20%20%20param_reduction%20%3D%20100.0%20*%20((physicsnemo_params%20-%20our_params)%20%2F%20physicsnemo_params)%0A%20%20%20%20param_ratio%20%3D%20physicsnemo_params%20%2F%20max(our_params%2C%201)%0A%0A%20%20%20%20state_line%20%3D%20%22%22%0A%20%20%20%20if%20diag_stats%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20state_line%20%3D%20(%0A%20%20%20%20%20%20%20%20%20%20%20%20f%22%5Cn-%20Current%20notebook%20state%20Linf%20error%3A%20%60%7Bdiag_stats%5B'state_linf'%5D%3A.3e%7D%60%22%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20speed_line%20%3D%20%22%22%0A%20%20%20%20if%20train_stats%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20speed_line%20%3D%20(%0A%20%20%20%20%20%20%20%20%20%20%20%20f%22%5Cn-%20This%20run%20time%20per%20iteration%3A%20%60%7Btrain_stats%5B's_per_iter'%5D%3A.3f%7Ds%2Fiter%60%22%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20devices%20%3D%20tuple(jax.devices())%0A%20%20%20%20platform_counts%3A%20dict%5Bstr%2C%20int%5D%20%3D%20%7B%7D%0A%20%20%20%20device_kinds%3A%20list%5Bstr%5D%20%3D%20%5B%5D%0A%20%20%20%20for%20dev%20in%20devices%3A%0A%20%20%20%20%20%20%20%20platform%20%3D%20str(dev.platform)%0A%20%20%20%20%20%20%20%20platform_counts%5Bplatform%5D%20%3D%20platform_counts.get(platform%2C%200)%20%2B%201%0A%20%20%20%20%20%20%20%20device_kinds.append(str(dev.device_kind))%0A%20%20%20%20platform_summary%20%3D%20%22%2C%20%22.join(%0A%20%20%20%20%20%20%20%20f%22%7Bname%7D%3A%7Bcount%7D%22%20for%20name%2C%20count%20in%20sorted(platform_counts.items())%0A%20%20%20%20)%0A%20%20%20%20if%20not%20platform_summary%3A%0A%20%20%20%20%20%20%20%20platform_summary%20%3D%20%22unknown%22%0A%20%20%20%20kinds_summary%20%3D%20%22%2C%20%22.join(sorted(set(device_kinds)))%0A%20%20%20%20if%20not%20kinds_summary%3A%0A%20%20%20%20%20%20%20%20kinds_summary%20%3D%20%22unknown%22%0A%0A%20%20%20%20comparison_panel%20%3D%20mo.md(%0A%20%20%20%20%20%20%20%20f%22%22%22%0A%20%20%20%20%23%23%20Comparison%20with%20PhysicsNeMo%20Spring-Mass%20Example%0A%0A%20%20%20%20From%20NVIDIA's%20published%20PhysicsNeMo%20spring-mass%20example%3A%0A%0A%20%20%20%20-%20Docs%20page%3A%20https%3A%2F%2Fdocs.nvidia.com%2Fphysicsnemo%2F25.11%2Fphysicsnemo-sym%2Fuser_guide%2Ffoundational%2Fode_spring_mass.html%0A%20%20%20%20-%20Config%3A%20https%3A%2F%2Fgithub.com%2FNVIDIA%2Fphysicsnemo-sym%2Fblob%2Fmain%2Fexamples%2Fode_spring_mass%2Fconf%2Fconfig.yaml%0A%20%20%20%20-%20Solver%3A%20https%3A%2F%2Fgithub.com%2FNVIDIA%2Fphysicsnemo-sym%2Fblob%2Fmain%2Fexamples%2Fode_spring_mass%2Fspring_mass_solver.py%0A%20%20%20%20-%20FC%20arch%20defaults%3A%20https%3A%2F%2Fgithub.com%2FNVIDIA%2Fphysicsnemo-sym%2Fblob%2Fmain%2Fphysicsnemo%2Fsym%2Fhydra%2Farch.py%0A%0A%20%20%20%20Training%2Fconfig%20context%3A%0A%0A%20%20%20%20-%20PhysicsNeMo%20steps%3A%20%60%7Bphysicsnemo_steps%3A%2C%7D%60%0A%20%20%20%20-%20This%20notebook%20steps%3A%20%60%7Bour_steps%3A%2C%7D%60%20(**%60%7Bstep_reduction%3A.1f%7D%25%60%20fewer**)%0A%0A%20%20%20%20-%20PhysicsNeMo%20interior%20points%2Fstep%3A%20%60%7Bphysicsnemo_interior%3A%2C%7D%60%0A%20%20%20%20-%20This%20notebook%20interior%20points%2Fstep%3A%20%60%7Bour_points_per_step%3A%2C%7D%60%20(**%60%7Bbatch_ratio%3A.0f%7Dx%60%20larger**)%0A%0A%20%20%20%20-%20PhysicsNeMo%20parameter%20count%20(default%20FC)%3A%20%60%7Bphysicsnemo_params%3A%2C%7D%60%0A%20%20%20%20-%20This%20notebook%20parameter%20count%3A%20%60%7Bour_params%3A%2C%7D%60%20(**%60%7Bparam_ratio%3A.1f%7Dx%60%20smaller**%2C%20%60%7Bparam_reduction%3A.1f%7D%25%60%20fewer)%0A%0A%20%20%20%20Conditioning%20and%20invariants%3A%0A%0A%20%20%20%20-%20ICs%20are%20enforced%20by%20construction%20in%20this%20notebook%20(no%20soft%20IC%20loss%20balancing).%0A%20%20%20%20-%20Energy%20drift%20is%20explicitly%20monitored.%0A%20%20%20%20%7Bstate_line%7D%0A%20%20%20%20%7Bspeed_line%7D%0A%0A%20%20%20%20Hardware%20context%20for%20this%20run%3A%20MacBook%20Pro%20M1%20Max%0A%0A%20%20%20%20**tl%3Bdr**%3A%20On%20a%20laptop%20with%20**no%20dedicated%20GPU**%2C%20we%20are%20able%20to%20**locally**%20solve%20this%20example%0A%20%20%20%20from%20the%20PhysicsNeMo%20documentation%20in%20a%20couple%20of%20minutes%2C%20with%20**20x**%20the%20batch%20size%2C%20%0A%20%20%20%20at%20**1%25%20of%20the%20iterations**%2C%20with%20**99%25%20fewer%20parameters**%2C%20while%20exactly%20satisfying%20***multiple%20exact%20initial%20%0A%20%20%20%20conditions%20simultaneously***.%20All%20using%20out-of-the-box%20Phydrax%20components.%0A%0A%20%20%20%20*Interested%20in%20custom%20optimized%20software%20for%20your%20use%20case%3F%20Email%20us%20at%20partner%40phydra.ai*%0A%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20comparison_panel%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20---%0A%20%20%20%20**Tips**%0A%0A%20%20%20%20-%20Run%20as%20a%20notebook%20editor%3A%20%60marimo%20edit%20examples%2Fspring_mass_ode.py%60%0A%20%20%20%20-%20Run%20as%20an%20app%3A%20%60python%20examples%2Fspring_mass_ode.py%60%0A%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
a8a57cdb1c36bf432e0f16c5a51cedcc90cbde0f85291b397bd5b56ddf1c3bbb