This notebook implements the one-variable dynamic-programming method developed in the manuscript for computing the expected number of fair die rolls required for the cumulative sum to first hit a target set.
We specialize here to the perfect-square target set
$\mathcal H=\{n^2:n\in\mathbb N\}.$
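The algorithm combines:
- a backward recursion over $s=N,N-1,\dots,0$, streamed in $O(N)$ time and $O(1)$ memory, for the truncated expectation $E_N(0)$ (the expected number of rolls until the sum either hits $\mathcal H$ or exceeds $N$) and for the overshoot probability $P_0(A_N)$ (the probability that the sum exceeds $N$ before hitting $\mathcal H$);
- closed-form bounds $L_N$ and $U_N$ on the conditional expected number of further rolls needed to hit $\mathcal H$ once the sum has first exceeded $N$,

yielding the certified approximation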
$E(0)\approx E_N(0)+L_NP_0(A_N),$
with rigorous error bound
$0<E(0)-\bigl(E_N(0)+L_NP_0(A_N)\bigr)<(U_N-L_N)P_0(A_N).$
The implementation uses arbitrary-precision floating-point arithmetic from `mpmath`. For the manuscript computation, we use $N=7000^2$ and 1200 decimal digits of working precision.
import math
import time
from collections import deque

from mpmath import mp
# ------------------------------------------------------------
# High-precision computation for the perfect-square target case
# Manuscript implementation
# ------------------------------------------------------------
# Global mpmath working precision, in decimal digits. It applies to every
# mp.mpf operation in this file. The example run prints 1019 significant
# digits of the result, so 1200 leaves roughly 180 guard digits for the
# rounding accumulated over the O(N) streaming sweeps.
mp.dps = 1200
# ------------------------------------------------------------
# helpers: closed forms of the geometric-type series used by
# L_N and U_N (each series converges only for |r| < 1)
# ------------------------------------------------------------
def S0(r):
    """Return sum_{j>=0} r^j, i.e. 1 / (1 - r)."""
    gap = 1 - r
    return 1 / gap


def S1(r):
    """Return sum_{j>=0} j * r^j, i.e. r / (1 - r)^2."""
    gap = 1 - r
    return r / gap ** 2


def S2(r):
    """Return sum_{j>=0} j^2 * r^j, i.e. r * (1 + r) / (1 - r)^3."""
    gap = 1 - r
    numerator = r * (1 + r)
    return numerator / gap ** 3
# ------------------------------------------------------------
# epsilon_N
# ------------------------------------------------------------
def epsilon_N(N):
    """Return epsilon_N = (5/7) * w_abs**(2K - 4) for N = K^2.

    ``L_N_U_N_closed_form`` uses this value to perturb the ratios 5/7 and
    2/7 to 5/7 -/+ epsilon_N and 2/7 -/+ epsilon_N.

    Parameters
    ----------
    N : int
        Truncation level; must be a perfect square K^2 with K >= 4.

    Returns
    -------
    mpf
        epsilon_N at the current mpmath working precision.

    Raises
    ------
    ValueError
        If N is not a perfect square, or if K < 4.
    """
    # Exact integer square root: unlike int(mp.sqrt(N)), this cannot be
    # thrown off by rounding when mp.dps is small relative to the size of N.
    # Negative N is clamped to 0 so it fails the K*K == N test below.
    K = math.isqrt(max(int(N), 0))
    if K * K != N:
        raise ValueError("Require N = K^2.")
    if K < 4:
        raise ValueError("Require K >= 4.")
    # NOTE(review): w_abs is a 10-digit constant taken from the manuscript,
    # presumably a bound on a root modulus; confirm its derivation there.
    w_abs = mp.mpf("0.7302499667")
    return (mp.mpf("5") / 7) * (w_abs ** (2 * K - 4))
# ------------------------------------------------------------
# closed-form L_N and U_N
# ------------------------------------------------------------
def L_N_U_N_closed_form(N):
    """Return (L_N, U_N) in closed form for N = K^2.

    With eps = epsilon_N(N) and

        P(r, d) = sum_{j>=0} ((K + 1 + j)^2 - K^2 - d) * r^j,

    the returned values are

        L_N = (1/6) * (2/7 - eps) * P(5/7 - eps, d=5),
        U_N =         (2/7 + eps) * P(5/7 + eps, d=1),

    where P is evaluated through the series closed forms S0, S1, S2
    (these require |5/7 +/- eps| < 1).

    Raises
    ------
    ValueError
        If N is not a perfect square K^2, or (from epsilon_N) if K < 4.
    """
    # Exact integer square root, independent of mp.dps. Negative N is
    # clamped to 0 so that it fails the K*K == N test below.
    K_int = math.isqrt(max(int(N), 0))
    if K_int * K_int != N:
        raise ValueError("Require N = K^2.")
    K = mp.mpf(K_int)
    eps = epsilon_N(N)
    five7 = mp.mpf("5") / 7
    two7 = mp.mpf("2") / 7

    def poly_sum(r, d):
        # Sum_{j>=0} ((K+1+j)^2 - K^2 - d) r^j
        # = Sum_{j>=0} [(2K+1-d) + 2(K+1)j + j^2] r^j
        a0 = 2 * K + 1 - d
        a1 = 2 * (K + 1)
        return a0 * S0(r) + a1 * S1(r) + S2(r)

    r_minus = five7 - eps
    r_plus = five7 + eps
    L = (mp.mpf("1") / 6) * (two7 - eps) * poly_sum(r_minus, d=5)
    U = (two7 + eps) * poly_sum(r_plus, d=1)
    return L, U
# ------------------------------------------------------------
# streaming E_N(0) and P_0(A_N): O(N) time, O(1) memory
# ------------------------------------------------------------
def _backward_sweep_at_zero(N, beyond, add_one):
    """Run the truncated one-step recursion from s = N down to s = 0.

    Returns f(0), where

        f(s) = 0                               if s = k^2 with k >= 1,
        f(s) = beyond                          if s > N,
        f(s) = c + (1/6) * sum_{i=1}^{6} f(s+i) otherwise,

    with c = 1 if ``add_one`` else 0. Only f(s+1), ..., f(s+6) are kept,
    so the sweep takes O(N) time and O(1) memory.

    Raises
    ------
    ValueError
        If N is not a perfect square K^2.
    """
    # Exact integer square root, independent of mp.dps. Negative N is
    # clamped to 0 so that it fails the K*K == N test below.
    K = math.isqrt(max(int(N), 0))
    if K * K != N:
        raise ValueError("Require N = K^2.")
    sixth = mp.mpf("1") / 6
    # Squares are met in decreasing order. The start point 0 is never
    # absorbing, hence the -1 sentinel once k reaches 0 (also when K == 0).
    k = K
    next_square = k * k if k > 0 else -1
    # window holds f(s+1), ..., f(s+6), most recent value on the left;
    # initially all six lie beyond N.
    window = deque([beyond] * 6, maxlen=6)
    window_sum = beyond * 6
    for s in range(K * K, -1, -1):
        if s == next_square:
            value = mp.mpf("0")
            k -= 1
            next_square = k * k if k > 0 else -1
        else:
            value = sixth * window_sum
            if add_one:
                value = 1 + value
        # Slide the window: drop f(s+6), push f(s).
        window_sum -= window.pop()
        window.appendleft(value)
        window_sum += value
    # The last iteration is s == 0.
    return value


def EN_at_zero_streaming(N):
    """Return E_N(0), the truncated expected number of rolls from 0.

    E_N(s) = 0 at squares k^2 (k >= 1) and for s > N; otherwise
    E_N(s) = 1 + (1/6) * sum_{i=1}^{6} E_N(s+i). O(N) time, O(1) memory.

    Raises ValueError unless N = K^2.
    """
    return _backward_sweep_at_zero(N, beyond=mp.mpf("0"), add_one=True)


def P_overshoot_zero_streaming(N):
    """Return P_0(A_N), the probability of exceeding N before hitting a square.

    P_s(A_N) = 0 at squares k^2 (k >= 1), P_s(A_N) = 1 for s > N; otherwise
    P_s(A_N) = (1/6) * sum_{i=1}^{6} P_{s+i}(A_N). O(N) time, O(1) memory.

    Raises ValueError unless N = K^2.
    """
    return _backward_sweep_at_zero(N, beyond=mp.mpf("1"), add_one=False)
# ------------------------------------------------------------
# Theorem 2 objects
# ------------------------------------------------------------
def truncation_objects_fast(N):
    """Assemble the Theorem 2 approximation of E(0) and its error bound.

    Returns the tuple

        (approx, err_bd, E_N(0), P_0(A_N), L_N, U_N)

    with approx = E_N(0) + L_N * P_0(A_N) and
    err_bd = (U_N - L_N) * P_0(A_N).
    """
    e_trunc = EN_at_zero_streaming(N)
    p_over = P_overshoot_zero_streaming(N)
    lower, upper = L_N_U_N_closed_form(N)
    return (
        e_trunc + lower * p_over,
        (upper - lower) * p_over,
        e_trunc,
        p_over,
        lower,
        upper,
    )
# ------------------------------------------------------------
# example run (Theorem 2)
# ------------------------------------------------------------
if __name__ == "__main__":
    # Manuscript parameters: N = 7000^2 at the module-level mp.dps.
    N = 7000**2
    # perf_counter is monotonic and high-resolution, i.e. the clock meant
    # for elapsed-time measurement; time.time can jump with the system clock.
    start_time = time.perf_counter()
    approx, err_bd, EN0, P0A, L, U = truncation_objects_fast(N)
    end_time = time.perf_counter()
    runtime = end_time - start_time
    print("Approximation E(0):E_N(0) + L_N * P_0(A_N)=")
    # 1019 significant digits matches the ~6e-1019 error bound obtained for
    # this N; revisit it if N or mp.dps changes.
    print(mp.nstr(approx, 1019))
    print("\nRigorous error bound: (U_N - L_N) * P_0(A_N)=")
    print(mp.nstr(err_bd, 40))
    print("\nE_N(0) =", mp.nstr(EN0, 40))
    print("P_0(A_N) =", mp.nstr(P0A, 40))
    print("L_N =", mp.nstr(L, 40))
    print("U_N =", mp.nstr(U, 40))
    print("\nRuntime =", round(runtime, 2), "seconds")
Approximation E(0):E_N(0) + L_N * P_0(A_N)= 7.0797642375511051038955530569081848946817114442632088059088731015172930306366572891506194402159295861406438530582366178390388054374270371619832251988435018692956813776498234440715233888008820745531068102279351912201497399312969543765589331921953693949583510111531141117999190881385051385993572642734582955346536537055487204771303737046494496070462752088408207916153631835937869840855942020528844752082478429005182914578014262554948325908230305047748136841290303836186610919947293463168991266582586867447217766236921643399987864860706302563591722932139301311266606613053537270912752830338219095963711644633463513431432765595336790889294337954095920737733995182964165404702948953236236292249929974777608530539038939897313532871879311367044360942347466466639767034394677123411717186190853174073085523878990940735330862620576871435574740617634573981362411821384208149298964783485461265863125045090486560427361918575110987791166131481796648503799789876560916765012999442010867907907370715432893078841972719702050899067753875 Rigorous error bound: (U_N - L_N) * P_0(A_N)= 6.163754194086475579815333888354168235404e-1019 E_N(0) = 7.079764237551105103895553056908184894682 P_0(A_N) = 1.508850331472307815412722898448210123557e-1023 L_N = 8169.333333333333333333333333333333333333 U_N = 49020.0 Runtime = 653.18 seconds