"""
    DiceRolling

Dice expressions such as `2d6 + d4adv - 1` and the exact probability distribution of
their totals, computed by FFT-based convolution.
"""
module DiceRolling

export d4, d6, d8, d10, d12, d20,
    d4adv, d6adv, d8adv, d10adv, d12adv, d20adv,
    d4dis, d6dis, d8dis, d10dis, d12dis, d20dis,
    distribution, pte

using FFTW
using OffsetArrays
using Requires

"""
    AbstractDice

Supertype of every object that can take part in a dice expression.
"""
abstract type AbstractDice end

"""
    BaseDie{F}

Supertype of single dice; the face count `F` is carried as a type parameter.
"""
abstract type BaseDie{F} <: AbstractDice end

"""
    Die{F}

Plain die: each of the faces `1:F` has probability `1 / F`.
"""
struct Die{F} <: BaseDie{F} end

"""
    AdvantageDie{F}

Die whose face `r` has probability `(2r - 1) / F^2`, i.e. the higher of two rolls.
"""
struct AdvantageDie{F} <: BaseDie{F} end

"""
    DisadvantageDie{F}

Die whose face `r` has probability `(2(F - r + 1) - 1) / F^2`, i.e. the lower of two
rolls.
"""
struct DisadvantageDie{F} <: BaseDie{F} end

# Same kind of die with the sign of the face parameter flipped. Used by `DieRolls` to
# normalise a die that was constructed with a negative face count.
_flipfaces(::Die{F}) where {F} = Die{-F}()
_flipfaces(::AdvantageDie{F}) where {F} = AdvantageDie{-F}()
_flipfaces(::DisadvantageDie{F}) where {F} = DisadvantageDie{-F}()

"""
    DieRolls(rolls::Integer, die::BaseDie)
    DieRolls(rolls::Integer, faces::Integer)

`rolls` repetitions of `die`. A negative `rolls` means the rolled total is subtracted.

A die with a negative face parameter is normalised on construction: the sign is moved
onto `rolls`, so the stored die always has the face parameter negated back to positive.
The integer-faces form builds a plain `Die`.
"""
struct DieRolls <: AbstractDice
    rolls::Int
    die::BaseDie

    function DieRolls(rolls::Integer, die::BaseDie{F}) where {F}
        # `typeof(die)` is already fully parameterised, so `typeof(die){-F}` cannot be
        # formed; rebuild the die through `_flipfaces` instead.
        F < 0 && return new(-rolls, _flipfaces(die))
        return new(rolls, die)
    end
end
# `Int(...)` keeps the type parameter canonical (`Die{6}` rather than `Die{Int8(6)}`).
DieRolls(rolls::Integer, die::Integer) = DieRolls(rolls, Die{Int(die)}())

"""
    CompoundDieRolls(dice::AbstractVector, bonus::Integer = 0)
    CompoundDieRolls(d::Union{BaseDie,DieRolls}, bonus::Integer = 0)
    CompoundDieRolls(d::CompoundDieRolls)

Sum of several `DieRolls` terms plus a constant integer `bonus`.

Terms that share the same die are merged by adding their roll counts. The order of the
merged terms follows `Dict` iteration and is therefore unspecified. The one-argument
form taking a `CompoundDieRolls` copies the term vector.
"""
struct CompoundDieRolls <: AbstractDice
    die::Vector{DieRolls}
    bonus::Int

    CompoundDieRolls(d::CompoundDieRolls) = new(copy(d.die), d.bonus)

    function CompoundDieRolls(
        dice::AbstractVector{T}, bonus::Integer = 0
    ) where {T<:Union{BaseDie,DieRolls}}
        tally = Dict{BaseDie,Int}()
        for x in dice
            term = 1 * x  # promotes a bare die to `DieRolls`
            tally[term.die] = get(tally, term.die, 0) + dierolls(term)
        end
        merged = DieRolls[DieRolls(n, die) for (die, n) in tally]
        return new(merged, bonus)
    end
end
CompoundDieRolls(d::BaseDie, bonus::Integer = 0) = CompoundDieRolls([1d], bonus)
CompoundDieRolls(d::DieRolls, bonus::Integer = 0) = CompoundDieRolls([d], bonus)

# Accessors
diebase(d::BaseDie{F}) where {F} = F
diebase(d::DieRolls) = diebase(d.die)

dierolls(d::BaseDie) = 1
dierolls(d::DieRolls) = d.rolls

diebonus(::BaseDie) = 0
diebonus(::DieRolls) = 0
diebonus(d::CompoundDieRolls) = d.bonus

# The fundamental D&D dice
const d4 = Die{4}()
const d6 = Die{6}()
const d8 = Die{8}()
const d10 = Die{10}()
const d12 = Die{12}()
const d20 = Die{20}()
# with advantage
const d4adv = AdvantageDie{4}()
const d6adv = AdvantageDie{6}()
const d8adv = AdvantageDie{8}()
const d10adv = AdvantageDie{10}()
const d12adv = AdvantageDie{12}()
const d20adv = AdvantageDie{20}()
# and disadvantage
const d4dis = DisadvantageDie{4}()
const d6dis = DisadvantageDie{6}()
const d8dis = DisadvantageDie{8}()
const d10dis = DisadvantageDie{10}()
const d12dis = DisadvantageDie{12}()
const d20dis = DisadvantageDie{20}()

# Define math over dice objects
Base.:-(d::BaseDie) = DieRolls(-1, d)
Base.:-(d::DieRolls) = DieRolls(-d.rolls, d.die)
# Typed comprehensions keep the element type even when `d.die` is empty.
Base.:-(d::CompoundDieRolls) = CompoundDieRolls(DieRolls[-t for t in d.die], -d.bonus)

Base.:+(d::BaseDie) = DieRolls(1, d)
Base.:+(d::DieRolls) = d
Base.:+(d::CompoundDieRolls) = d

Base.:*(r::Integer, d::BaseDie) = DieRolls(Int(r), d)
Base.:*(r::Integer, d::DieRolls) = DieRolls(Int(r * d.rolls), d.die)
function Base.:*(r::Integer, d::CompoundDieRolls)
    return CompoundDieRolls(DieRolls[r * t for t in d.die], r * d.bonus)
end

Base.:+(b::Integer, d::AbstractDice) = d + b
Base.:-(b::Integer, d::AbstractDice) = -d + b
Base.:+(d::AbstractDice, b::Integer) = CompoundDieRolls([d], b)
Base.:-(d::AbstractDice, b::Integer) = CompoundDieRolls([d], -b)
function Base.:+(d1::AbstractDice, d2::AbstractDice)
    return CompoundDieRolls(d1) + CompoundDieRolls(d2)
end
function Base.:-(d1::AbstractDice, d2::AbstractDice)
    return CompoundDieRolls(d1) - CompoundDieRolls(d2)
end

Base.:+(d::CompoundDieRolls, b::Integer) = CompoundDieRolls(d.die, d.bonus + b)
Base.:-(d::CompoundDieRolls, b::Integer) = CompoundDieRolls(d.die, d.bonus - b)
function Base.:+(d1::CompoundDieRolls, d2::CompoundDieRolls)
    return CompoundDieRolls(vcat(d1.die, d2.die), d1.bonus + d2.bonus)
end
Base.:-(d1::CompoundDieRolls, d2::CompoundDieRolls) = d1 + -d2

# Smallest / largest total that can be rolled.
# A die with negative face parameter `F` covers `F:-1`, so its minimum is `F` itself.
Base.minimum(d::BaseDie{F}) where {F} = F < 0 ? F : 1
Base.maximum(d::BaseDie{F}) where {F} = F < 0 ? -1 : F
# With a negative roll count the extremes swap: the most negative total comes from
# every die showing its maximum.
Base.minimum(d::DieRolls) = d.rolls < 0 ? d.rolls * maximum(d.die) : d.rolls
Base.maximum(d::DieRolls) = d.rolls > 0 ? d.rolls * maximum(d.die) : d.rolls
# `init = 0` makes a compound without any dice reduce to just its bonus.
Base.minimum(d::CompoundDieRolls) = mapreduce(minimum, +, d.die; init = 0) + d.bonus
Base.maximum(d::CompoundDieRolls) = mapreduce(maximum, +, d.die; init = 0) + d.bonus
Base.extrema(d::AbstractDice) = (minimum(d), maximum(d))

# Pretty-printing of dice rolling objects
function Base.show(io::IO, d::Die{F}) where {F}
    return F > 0 ? print(io, "d", F) : print(io, "-1d", -F)
end
function Base.show(io::IO, d::AdvantageDie{F}) where {F}
    return F > 0 ? print(io, "d", F, "adv") : print(io, "-1d", -F, "adv")
end
function Base.show(io::IO, d::DisadvantageDie{F}) where {F}
    return F > 0 ? print(io, "d", F, "dis") : print(io, "-1d", -F, "dis")
end
Base.show(io::IO, d::DieRolls) = print(io, d.rolls, d.die)

function Base.show(io::IO, d::CompoundDieRolls)
    if isempty(d.die)
        # No dice at all: print the bare constant instead of a dangling " + n".
        print(io, d.bonus)
        return nothing
    end
    for idx in eachindex(d.die)
        term = d.die[idx]
        if idx == firstindex(d.die)
            print(io, term)
        elseif dierolls(term) < 0
            print(io, " - ", -term)
        else
            print(io, " + ", term)
        end
    end
    if d.bonus < 0
        print(io, " - ", -d.bonus)
    elseif d.bonus > 0
        print(io, " + ", d.bonus)
    end
    return nothing
end

# Write the single-roll probabilities of `die` into `P`, where `P[k + 1]` holds the
# probability of rolling `k` (so `P[1]` is the probability of 0). Entries outside
# `2:F+1` are left untouched; callers zero `P` beforehand.
function _base_distribution!(P::Vector, die::Die)
    F = diebase(die)
    P[2:F+1] .= 1 / F
    return P
end
function _base_distribution!(P::Vector, die::AdvantageDie)
    F = diebase(die)
    R = 1:F
    P[2:F+1] .= (2 .* R .- 1) ./ F^2
    return P
end
function _base_distribution!(P::Vector, die::DisadvantageDie)
    F = diebase(die)
    R = F:-1:1
    P[2:F+1] .= (2 .* R .- 1) ./ F^2
    return P
end

"""
    distribution(d::AbstractDice)

Return the probability distribution of the total of `d` as an `OffsetVector` that is
indexed by the total itself, i.e. `distribution(d)[k]` is the probability of rolling
exactly `k`.

The distributions of the individual terms are combined by multiplying their real FFTs,
which is a (circular) convolution of the distributions.
"""
function distribution(d::CompoundDieRolls)
    # Deal with the bonus purely in terms of axis indices
    if d.bonus != 0
        bonus = d.bonus
        P = distribution(d - bonus)
        return OffsetVector(parent(P), axes(P, 1) .+ bonus)
    end

    # Calculate the "negative" distribution if it is biased negative
    if abs(minimum(d)) > abs(maximum(d))
        P = distribution(-d)
        ax = axes(P, 1)
        return OffsetVector(reverse!(parent(P)), -last(ax):-first(ax))
    end

    maxN = maximum(d)
    # if the distribution crosses the zero line, make the "negative frequencies"
    # disjoint from the positive portion
    len = minimum(d) < 0 ? 2maxN + 1 : maxN + 1
    P = Vector{Float64}(undef, len)
    spectrum = ones(ComplexF64, (len >> 1) + 1)  # real-only half-length FFT

    for term in d.die
        rolls = dierolls(term)
        rolls == 0 && continue  # skip if contributing no rolls

        fill!(P, 0.0)
        _base_distribution!(P, term.die)
        if rolls < 0
            # decrements to distribution occur by reflecting over 0; by periodicity,
            # the negative indices are second half of the vector. (See `fftshift()`
            # and `fftfreq()`.)
            # start at index 2 (freq == 1) to keep DC (freq == 0) term in place
            reverse!(P, 2)
        end
        # Multiplying spectra convolves this term into the running distribution;
        # the power accounts for rolling the same die `abs(rolls)` times.
        spectrum .*= rfft(P) .^ abs(rolls)
    end
    P .= irfft(spectrum, len)

    if minimum(d) < 0
        return OffsetVector(fftshift(P), -maxN:maxN)
    else
        return OffsetVector(P, 0:maxN)
    end
end
distribution(d::AbstractDice) = distribution(1 * d + 0)

"""
    pte(d::AbstractDice, value::Integer)

Probability that the total of `d` is strictly greater than `value`.

Returns `1.0` when `value` is below the minimum total and `0.0` when it is at or above
the maximum total, without computing the distribution.
"""
function pte(d::AbstractDice, value::Integer)
    lo, hi = extrema(d)
    value < lo && return 1.0
    value ≥ hi && return 0.0
    P = distribution(d)
    return sum(P[value+1:end])
end

# Requires.jl only registers `@require` callbacks that run from `__init__`; at module
# top level the block is skipped while precompiling and never runs on a later load.
function __init__()
    @require RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" begin
        using .RecipesBase

        @eval @recipe function f(dice::AbstractDice; xshift = 0.0)
            l, h = extrema(dice)
            P = distribution(dice)

            seriestype := :sticks
            xlims --> tuple(l ≥ 0 ? 0 : -Inf, h ≤ 0 ? 0 : Inf)  # avoid suppressing zero
            label --> sprint(show, dice)
            yguide --> "probability"
            xguide --> "total of rolls"

            return ((l:h) .+ xshift, parent(P[l:h]))
        end
    end
    return nothing
end

end