High-Precision Expected Hitting Time for the Dice-Sum Process

This notebook implements the one-variable dynamic-programming method developed in the manuscript for computing the expected number of rolls of a fair six-sided die required for the cumulative sum to first hit a target set.

We specialize here to the perfect-square target set

$\mathcal H=\{n^2:n\in\mathbb N\},\qquad \mathbb N=\{1,2,3,\dots\},$

so the starting point $0$ itself is not a target.

The algorithm combines:

  • a truncated backward recursion for $E_N(0)$,
  • a backward recursion for the overshoot probability $P_0(A_N)$,
  • closed-form overshoot bounds $L_N$ and $U_N$,

yielding the certified approximation

$E(0)\approx E_N(0)+L_NP_0(A_N),$

with rigorous error bound

$0<E(0)-\bigl(E_N(0)+L_NP_0(A_N)\bigr)<(U_N-L_N)P_0(A_N).$

The implementation uses:

  • $O(N)$ arithmetic operations,
  • $O(1)$ memory via rolling-window updates,
  • arbitrary-precision arithmetic via mpmath.

For the manuscript computation, we use $N=7000^2$ and 1200 decimal digits of working precision.

In [1]:
from mpmath import mp
from collections import deque

# ------------------------------------------------------------
# High-precision computation for the perfect-square target case
# Manuscript implementation
# ------------------------------------------------------------

# Working precision, in significant decimal digits, for every mpmath
# operation in this notebook (`mp` is mpmath's global context, so this
# setting affects all functions below).
WORKING_DPS = 1200
mp.dps = WORKING_DPS

# ------------------------------------------------------------
# helpers: geometric sums
# ------------------------------------------------------------
def S0(r):
    """Return the geometric series sum_{j>=0} r**j in closed form, 1 / (1 - r).

    The closed form equals the series only for |r| < 1; this is not checked.
    """
    complement = 1 - r
    return 1 / complement


def S1(r):
    """Return sum_{j>=0} j * r**j in closed form, r / (1 - r)**2.

    The closed form equals the series only for |r| < 1; this is not checked.
    """
    complement = 1 - r
    return r / complement ** 2


def S2(r):
    """Return sum_{j>=0} j**2 * r**j in closed form, r (1 + r) / (1 - r)**3.

    The closed form equals the series only for |r| < 1; this is not checked.
    """
    denominator = (1 - r) ** 3
    return r * (1 + r) / denominator


# ------------------------------------------------------------
# epsilon_N
# ------------------------------------------------------------
def epsilon_N(N):
    """Return the perturbation radius epsilon_N used by the overshoot bounds.

    For N = K**2 this is (5/7) * w**(2K - 4), where w = 0.7302499667 is a
    hard-coded constant (presumably a bound on the modulus of the
    subdominant roots from the manuscript's analysis -- TODO confirm).

    Raises:
        ValueError: if N is not a perfect square, or if K < 4.
    """
    root = int(mp.sqrt(N))
    if root * root != N:
        raise ValueError("Require N = K^2.")
    if root < 4:
        raise ValueError("Require K >= 4.")

    exponent = 2 * root - 4
    w_abs = mp.mpf("0.7302499667")
    prefactor = mp.mpf("5") / 7
    return prefactor * (w_abs ** exponent)


# ------------------------------------------------------------
# closed-form L_N and U_N
# ------------------------------------------------------------
def L_N_U_N_closed_form(N):
    """Return the closed-form overshoot bounds (L_N, U_N) for N = K**2.

    Both bounds evaluate the series

        T(r, d) = sum_{j>=0} ((K + 1 + j)**2 - K**2 - d) * r**j

    through the geometric sums S0, S1, S2, with r perturbed by
    eps = epsilon_N(N) around 5/7:

        L_N = (1/6) * (2/7 - eps) * T(5/7 - eps, 5)
        U_N =         (2/7 + eps) * T(5/7 + eps, 1)

    Raises:
        ValueError: if N is not a perfect square (checked here), or if
            K < 4 (raised by epsilon_N).
    """
    root = int(mp.sqrt(N))
    if root * root != N:
        raise ValueError("Require N = K^2.")

    K = mp.mpf(root)
    eps = epsilon_N(N)
    ratio = mp.mpf("5") / 7
    weight = mp.mpf("2") / 7

    def weighted_series(r, d):
        # (K+1+j)^2 - K^2 - d = (2K+1-d) + 2(K+1) j + j^2, so the series
        # splits into the three geometric sums S0, S1, S2.
        constant_coeff = 2 * K + 1 - d
        linear_coeff = 2 * (K + 1)
        return constant_coeff * S0(r) + linear_coeff * S1(r) + S2(r)

    lower = (mp.mpf("1") / 6) * (weight - eps) * weighted_series(ratio - eps, d=5)
    upper = (weight + eps) * weighted_series(ratio + eps, d=1)
    return lower, upper


# ------------------------------------------------------------
# streaming E_N(0): O(N) time, O(1) memory
# ------------------------------------------------------------
def EN_at_zero_streaming(N):
    """Return the truncated expected hitting time E_N(0) for N = K**2.

    The values E_N(s) are produced by the backward recursion

        E_N(s) = 0                                     for s > N,
        E_N(s) = 0                                     for s = k**2, 1 <= k <= K,
        E_N(s) = 1 + (E_N(s+1) + ... + E_N(s+6)) / 6  otherwise,

    sweeping s from N down to 0. Only the six most recent values are
    retained (a deque of length 6 plus their running sum), so memory is
    O(1) and the sweep performs O(N) arithmetic operations.

    Raises:
        ValueError: if N is not a perfect square.
    """
    root = int(mp.sqrt(N))
    if root * root != N:
        raise ValueError("Require N = K^2.")

    one_sixth = mp.mpf("1") / 6

    # window[0] holds E_N(s+1), ..., window[5] holds E_N(s+6); they start at
    # the boundary value E_N(s) = 0 for s > N.
    window = deque([mp.mpf("0")] * 6, maxlen=6)
    window_sum = mp.mpf("0")

    # Squares are met in decreasing order k = K, K-1, ..., 1; after k = 1 the
    # sentinel -1 is never matched, so s = 0 is not treated as a target.
    side = root
    target = side * side

    for s in range(N, -1, -1):
        if s != target:
            value = 1 + one_sixth * window_sum
        else:
            value = mp.mpf("0")
            side -= 1
            target = -1 if side <= 0 else side * side

        # Slide the window: drop E_N(s+6), prepend E_N(s).
        window_sum = window_sum - window.pop()
        window.appendleft(value)
        window_sum = window_sum + value

    return value


# ------------------------------------------------------------
# streaming P_0(A_N): O(N) time, O(1) memory
# ------------------------------------------------------------
def P_overshoot_zero_streaming(N):
    """Return the overshoot probability P_0(A_N) for N = K**2.

    The values Q_s = P_s(A_N) are produced by the backward recursion

        Q_s = 1                              for s > N,
        Q_s = 0                              for s = k**2, 1 <= k <= K,
        Q_s = (Q_{s+1} + ... + Q_{s+6}) / 6  otherwise,

    i.e. the probability that the walk started at s jumps past N before
    landing on a positive square. The sweep runs s = N, ..., 0 keeping only
    the six most recent values (a deque of length 6 plus their running sum),
    so memory is O(1) and the work is O(N) arithmetic operations.

    Raises:
        ValueError: if N is not a perfect square.
    """
    K = int(mp.sqrt(N))
    if K * K != N:
        raise ValueError("Require N = K^2.")

    k = K
    next_square = k * k
    sixth = mp.mpf("1") / 6

    # boundary values P_s(A_N)=1 for s>N
    buf = deque([mp.mpf("1")] * 6, maxlen=6)
    sum6 = mp.mpf("6")

    for s in range(N, -1, -1):
        at_square = s == next_square
        if at_square:
            Q_s = mp.mpf("0")
            k -= 1
            # after k = 1 the sentinel -1 is never matched: 0 is not a target
            next_square = k * k if k > 0 else -1
        else:
            Q_s = sixth * sum6

        oldest = buf.pop()
        buf.appendleft(Q_s)

        if at_square:
            # Resynchronise the running sum from the window itself. The
            # incremental update below carries an absolute rounding drift of
            # order 10**-mp.dps times the largest value ever summed (O(1)
            # near s = N), and that drift never decays. Q_s, however, decays
            # toward 0 as s moves down (about 1.5e-1023 at s = 0 for
            # N = 7000**2), so without a reset the drift would swamp the
            # result once Q_s falls below ~10**-mp.dps. One fsum per square
            # (K in total) keeps the error relative to the current magnitude
            # at negligible cost.
            sum6 = mp.fsum(buf)
        else:
            sum6 -= oldest
            sum6 += Q_s

    return Q_s


# ------------------------------------------------------------
# Theorem 2 objects
# ------------------------------------------------------------
def truncation_objects_fast(N):
    """Assemble the Theorem 2 quantities for N = K**2.

    Returns the tuple (approx, err_bd, EN0, P0A, L, U) where

        approx = E_N(0) + L_N * P_0(A_N)
        err_bd = (U_N - L_N) * P_0(A_N)

    and EN0, P0A, L, U are E_N(0), P_0(A_N), L_N, U_N respectively.

    Raises:
        ValueError: propagated from the helpers if N is not a perfect
            square or K < 4.
    """
    en_zero = EN_at_zero_streaming(N)
    p_zero = P_overshoot_zero_streaming(N)
    lower, upper = L_N_U_N_closed_form(N)

    return (
        en_zero + lower * p_zero,
        (upper - lower) * p_zero,
        en_zero,
        p_zero,
        lower,
        upper,
    )
In [2]:
# ------------------------------------------------------------
# example run (Theorem 2)
# ------------------------------------------------------------
import time

if __name__ == "__main__":

    N = 7000**2

    # perf_counter is the monotonic, high-resolution clock intended for
    # measuring elapsed time; time.time() follows the wall clock and can jump.
    start_time = time.perf_counter()

    approx, err_bd, EN0, P0A, L, U = truncation_objects_fast(N)

    end_time = time.perf_counter()
    runtime = end_time - start_time

    # Significant digits of `approx` resolved by the error bound: the decimal
    # places above err_bd plus the integer digits of approx (1019 for
    # N = 7000**2, where err_bd ~ 6.2e-1019). Derived rather than hard-coded
    # so the printout stays consistent if N or the precision is changed;
    # never more digits than the working precision carries.
    if err_bd > 0 and approx > 0:
        n_digits = (
            int(mp.floor(-mp.log10(err_bd)))
            + int(mp.floor(mp.log10(approx)))
            + 1
        )
        n_digits = max(1, min(n_digits, mp.dps))
    else:
        n_digits = mp.dps

    print("Approximation E(0):E_N(0) + L_N * P_0(A_N)=")
    print(mp.nstr(approx, n_digits))

    print("\nRigorous error bound: (U_N - L_N) * P_0(A_N)=")
    print(mp.nstr(err_bd, 40))

    print("\nE_N(0) =", mp.nstr(EN0, 40))
    print("P_0(A_N) =", mp.nstr(P0A, 40))
    print("L_N =", mp.nstr(L, 40))
    print("U_N =", mp.nstr(U, 40))

    print("\nRuntime =", round(runtime, 2), "seconds")
Approximation E(0):E_N(0) + L_N * P_0(A_N)=
7.0797642375511051038955530569081848946817114442632088059088731015172930306366572891506194402159295861406438530582366178390388054374270371619832251988435018692956813776498234440715233888008820745531068102279351912201497399312969543765589331921953693949583510111531141117999190881385051385993572642734582955346536537055487204771303737046494496070462752088408207916153631835937869840855942020528844752082478429005182914578014262554948325908230305047748136841290303836186610919947293463168991266582586867447217766236921643399987864860706302563591722932139301311266606613053537270912752830338219095963711644633463513431432765595336790889294337954095920737733995182964165404702948953236236292249929974777608530539038939897313532871879311367044360942347466466639767034394677123411717186190853174073085523878990940735330862620576871435574740617634573981362411821384208149298964783485461265863125045090486560427361918575110987791166131481796648503799789876560916765012999442010867907907370715432893078841972719702050899067753875

Rigorous error bound: (U_N - L_N) * P_0(A_N)=
6.163754194086475579815333888354168235404e-1019

E_N(0) = 7.079764237551105103895553056908184894682
P_0(A_N) = 1.508850331472307815412722898448210123557e-1023
L_N = 8169.333333333333333333333333333333333333
U_N = 49020.0

Runtime = 653.18 seconds