Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

Full parity with the R `moderndive` and `infer` packages.

- **R argument parity (infer)**: `hypothesize(med=)`/`observe(med=)` add a median point null; `prop_test()` gains `z`, `correct` (Yates), `conf_int`, and `conf_level` (now matching R's `prop.test` — chi-square by default, with a Wilson-score CI for one proportion); `rep_slice_sample(prop=, weight_by=)` and `rep_sample_n(prob=)` add fractional and weighted sampling; `generate(variables=)` chooses which column to permute; `shade_p_value`/`shade_confidence_interval` gain `fill`; `visualize(dens_color=)` sets the theoretical-curve color. (`shade_p_value` now also honors `color`.)

- **Chi-square goodness-of-fit** (closes the last `infer` vignette gap):
`specify(response=cat).hypothesize(null="point", p={level: prob, ...})` with
`generate(type="draw")` and `calculate(stat="Chisq")` now runs a one-variable
Expand Down
53 changes: 43 additions & 10 deletions moderndive/infer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,25 @@ def hypothesize(
null: str,
*,
mu: float | None = None,
med: float | None = None,
p: float | dict | None = None,
sigma: float | None = None,
) -> Hypothesis:
if null not in ("point", "independence", "paired independence"):
raise ValueError("null must be 'point', 'independence', or 'paired independence'")
return Hypothesis(spec=self, null=null, mu=mu, p=p, sigma=sigma)
return Hypothesis(spec=self, null=null, mu=mu, med=med, p=p, sigma=sigma)

def generate(
self, reps: int, type: str = "bootstrap", *, seed: int | None = None
self,
reps: int,
type: str = "bootstrap",
*,
variables: str | None = None,
seed: int | None = None,
) -> GeneratedReplicates:
return _generate(self, hypothesis=None, reps=reps, type=type, seed=seed)
return _generate(
self, hypothesis=None, reps=reps, type=type, variables=variables, seed=seed
)

def calculate(
self,
Expand Down Expand Up @@ -124,8 +132,8 @@ def assume(self, distribution: str, df: object | None = None):
return _assume(distribution, df=df)

# British-spelling alias (infer parity).
def hypothesise(self, null: str, *, mu=None, p=None, sigma=None):
return self.hypothesize(null, mu=mu, p=p, sigma=sigma)
def hypothesise(self, null: str, *, mu=None, med=None, p=None, sigma=None):
return self.hypothesize(null, mu=mu, med=med, p=p, sigma=sigma)

def fit(self) -> FitResult:
"""Fit the observed regression (ordinary least squares) for the formula."""
Expand All @@ -144,15 +152,23 @@ class Hypothesis:
spec: Specification
null: str
mu: float | None = None
med: float | None = None
p: float | dict | None = None
sigma: float | None = None

def generate(
self, reps: int, type: str | None = None, *, seed: int | None = None
self,
reps: int,
type: str | None = None,
*,
variables: str | None = None,
seed: int | None = None,
) -> GeneratedReplicates:
if type is None:
type = "bootstrap" if self.null == "point" else "permute"
return _generate(self.spec, hypothesis=self, reps=reps, type=type, seed=seed)
return _generate(
self.spec, hypothesis=self, reps=reps, type=type, variables=variables, seed=seed
)

def calculate(
self,
Expand Down Expand Up @@ -182,6 +198,7 @@ class GeneratedReplicates:
hyp_mu: float | None = None
hyp_p: float | dict | None = None
hyp_sigma: float | None = None
variables: str | None = None

# --- single-variable / two-group statistics ---------------------------
def calculate(self, stat, *, order: tuple[object, object] | None = None) -> Distribution:
Expand All @@ -201,8 +218,12 @@ def calculate(self, stat, *, order: tuple[object, object] | None = None) -> Dist
resp = resp_full * plan # randomly flip the sign of each difference
expl = None
elif self.type == "permute":
resp = resp_full
expl = None if expl_full is None else expl_full[plan]
if self.variables is not None and self.variables == spec.response:
resp = resp_full[plan] # permute the response instead of the explanatory
expl = expl_full
else:
resp = resp_full
expl = None if expl_full is None else expl_full[plan]
else: # draw
resp = plan # the simulated response array
expl = None
Expand Down Expand Up @@ -267,18 +288,22 @@ def _generate(
hypothesis: Hypothesis | None,
reps: int,
type: str,
variables: str | None = None,
seed: int | None,
) -> GeneratedReplicates:
if type == "simulate": # infer accepts "simulate" as an alias for "draw"
type = "draw"
if type not in ("bootstrap", "permute", "draw"):
raise ValueError("type must be 'bootstrap', 'permute', 'draw', or 'simulate'")
if variables is not None and variables not in (spec.response, spec.explanatory):
raise ValueError(f"variables={variables!r} must be the response or explanatory variable.")
rng = _resample.make_rng(seed)
n = spec.data.height
null = None if hypothesis is None else hypothesis.null
hyp_mu = None if hypothesis is None else hypothesis.mu
hyp_p = None if hypothesis is None else hypothesis.p
hyp_sigma = None if hypothesis is None else hypothesis.sigma
hyp_med = None if hypothesis is None else hypothesis.med
shifted = None
plans: list[np.ndarray] = []

Expand Down Expand Up @@ -313,6 +338,10 @@ def _generate(
shifted = _resample.shift_for_point_null(
spec._response_values, stat="mean", mu=hypothesis.mu, p=None
)
elif hypothesis is not None and hypothesis.null == "point" and hyp_med is not None:
shifted = _resample.shift_for_point_null(
spec._response_values, stat="median", mu=hyp_med, p=None
)
for _ in range(reps):
plans.append(rng.integers(0, n, size=n))

Expand All @@ -325,6 +354,7 @@ def _generate(
hyp_mu=hyp_mu,
hyp_p=hyp_p,
hyp_sigma=hyp_sigma,
variables=variables,
)


Expand Down Expand Up @@ -494,6 +524,7 @@ def observe(
order: tuple[object, object] | None = None,
null: str | None = None,
mu: float | None = None,
med: float | None = None,
p: float | dict | None = None,
sigma: float | None = None,
) -> ObservedStatistic:
Expand All @@ -505,7 +536,9 @@ def observe(
data, response=response, explanatory=explanatory, formula=formula, success=success
)
if null is not None:
return spec.hypothesize(null=null, mu=mu, p=p, sigma=sigma).calculate(stat, order=order)
return spec.hypothesize(null=null, mu=mu, med=med, p=p, sigma=sigma).calculate(
stat, order=order
)
return spec.calculate(stat, order=order, mu=mu, p=p, sigma=sigma)


Expand Down
32 changes: 24 additions & 8 deletions moderndive/infer/viz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class ShadeSpec:
lower: float | None = None
upper: float | None = None
color: str | None = None
fill: str | None = None
per_term: tuple | None = None


Expand Down Expand Up @@ -182,6 +183,7 @@ def visualize(
*,
engine: str = "plotly",
method: str = "simulation",
dens_color: str | None = None,
shade_pvalue=None,
shade_ci=None,
**kwargs,
Expand All @@ -191,18 +193,19 @@ def visualize(
``method`` is ``"simulation"`` (histogram, default), ``"theoretical"`` (a
normal-approximation density curve), or ``"both"`` (histogram in density units
overlaid with the normal curve), mirroring R ``infer``'s ``visualize(method=)``.
``dens_color`` sets the theoretical-curve color (for ``"theoretical"``/``"both"``).
Pass ``shade_pvalue=``/``shade_ci=`` to shade in one call, or compose with ``+``.
"""
engine = C.resolve_engine(engine)
method = C.resolve_method(method)
if engine == "plotnine":
from . import _plotnine as P

fig = P.visualize_gg(distribution, bins, method)
fig = P.visualize_gg(distribution, bins, method, dens_color)
else:
from . import _plotly as PX

fig = PX.visualize_px(distribution, bins, method)
fig = PX.visualize_px(distribution, bins, method, dens_color)

plot = InferPlot(fig, engine)
if shade_pvalue is not None:
Expand Down Expand Up @@ -288,7 +291,9 @@ def _per_term_ci(endpoints) -> dict | None:
return None


def shade_p_value(obs_stat, direction: str, *, color: str | None = None) -> ShadeSpec:
def shade_p_value(
obs_stat, direction: str, *, color: str | None = None, fill: str | None = None
) -> ShadeSpec:
"""A p-value shading spec; add it to a ``visualize()`` plot with ``+``.

``direction`` ∈ {right/greater, left/less, two-sided}. For a faceted
Expand All @@ -298,12 +303,20 @@ def shade_p_value(obs_stat, direction: str, *, color: str | None = None) -> Shad
per = _per_term_obs(obs_stat)
if per is not None:
return ShadeSpec(
kind="p_value", direction=direction, color=color, per_term=tuple(sorted(per.items()))
kind="p_value",
direction=direction,
color=color,
fill=fill,
per_term=tuple(sorted(per.items())),
)
return ShadeSpec(kind="p_value", obs_stat=float(obs_stat), direction=direction, color=color)
return ShadeSpec(
kind="p_value", obs_stat=float(obs_stat), direction=direction, color=color, fill=fill
)


def shade_confidence_interval(endpoints, color: str | None = None) -> ShadeSpec:
def shade_confidence_interval(
endpoints, color: str | None = None, fill: str | None = None
) -> ShadeSpec:
"""A confidence-interval shading spec; add it to a ``visualize()`` plot with ``+``.

``endpoints`` is a CI DataFrame (``lower_ci``/``upper_ci``) or a ``(lower, upper)``
Expand All @@ -313,7 +326,10 @@ def shade_confidence_interval(endpoints, color: str | None = None) -> ShadeSpec:
per = _per_term_ci(endpoints)
if per is not None:
return ShadeSpec(
kind="confidence_interval", color=color, per_term=tuple(sorted(per.items()))
kind="confidence_interval",
color=color,
fill=fill,
per_term=tuple(sorted(per.items())),
)
lower, upper = C.ci_endpoints(endpoints)
return ShadeSpec(kind="confidence_interval", lower=lower, upper=upper, color=color)
return ShadeSpec(kind="confidence_interval", lower=lower, upper=upper, color=color, fill=fill)
41 changes: 24 additions & 17 deletions moderndive/infer/viz/_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,24 @@ def _layout(fig, title: str, xlab: str, ylab: str):
return fig


def density_curve_px(x, density, title: str, xlab: str = "statistic"):
def density_curve_px(x, density, title: str, xlab: str = "statistic", color: str | None = None):
"""A standalone theoretical density curve."""
go = _go()
fig = go.Figure(go.Scatter(x=x, y=density, mode="lines", line={"color": C._OBS_COLOR}))
fig = go.Figure(go.Scatter(x=x, y=density, mode="lines", line={"color": color or C._OBS_COLOR}))
return _layout(fig, title, xlab, "density")


def visualize_px(distribution, bins: int, method: str):
def visualize_px(distribution, bins: int, method: str, dens_color: str | None = None):
"""Histogram of simulated statistics, optionally overlaid with a normal curve."""
go = _go()
values = C.stat_values(distribution)
xlab = C.stat_label(distribution.stat)
title = C.dist_title(distribution.null)
curve_color = dens_color or C._OBS_COLOR

if method == "theoretical":
x, dens = C.normal_overlay(values)
return density_curve_px(x, dens, "Theoretical Distribution", xlab)
return density_curve_px(x, dens, "Theoretical Distribution", xlab, curve_color)

histnorm = "probability density" if method == "both" else None
fig = go.Figure(
Expand All @@ -54,7 +55,7 @@ def visualize_px(distribution, bins: int, method: str):
)
if method == "both":
x, dens = C.normal_overlay(values)
fig.add_scatter(x=x, y=dens, mode="lines", line={"color": C._OBS_COLOR})
fig.add_scatter(x=x, y=dens, mode="lines", line={"color": curve_color})
return _layout(fig, title, xlab, "density" if method == "both" else "count")


Expand Down Expand Up @@ -112,24 +113,27 @@ def apply_shade_px(fig, spec):
lo, hi = _data_range(out)

if spec.kind == "p_value":
lc = spec.color or C._OBS_COLOR
fc = spec.fill or spec.color or C._OBS_COLOR
vlines, rects = C.pvalue_regions(spec.obs_stat, spec.direction)
for x, dashed in vlines:
out.add_vline(
x=x, line={"color": C._OBS_COLOR, "width": 2, "dash": "dash" if dashed else "solid"}
x=x, line={"color": lc, "width": 2, "dash": "dash" if dashed else "solid"}
)
for xmin, xmax in rects:
out.add_vrect(
x0=_clip(xmin, lo, hi),
x1=_clip(xmax, lo, hi),
fillcolor=C._OBS_COLOR,
fillcolor=fc,
opacity=0.3,
line_width=0,
)
else: # confidence_interval
color = spec.color or C._SHADE_COLOR
out.add_vrect(x0=spec.lower, x1=spec.upper, fillcolor=color, opacity=0.3, line_width=0)
out.add_vline(x=spec.lower, line={"color": color, "width": 2})
out.add_vline(x=spec.upper, line={"color": color, "width": 2})
lc = spec.color or C._SHADE_COLOR
fc = spec.fill or spec.color or C._SHADE_COLOR
out.add_vrect(x0=spec.lower, x1=spec.upper, fillcolor=fc, opacity=0.3, line_width=0)
out.add_vline(x=spec.lower, line={"color": lc, "width": 2})
out.add_vline(x=spec.upper, line={"color": lc, "width": 2})
return out


Expand Down Expand Up @@ -157,6 +161,8 @@ def apply_fit_shade_px(fig, spec, terms):
per = dict(spec.per_term)

if spec.kind == "p_value":
lc = spec.color or C._OBS_COLOR
fc = spec.fill or spec.color or C._OBS_COLOR
for term, obs in per.items():
col = term_to_col.get(term)
if col is None:
Expand All @@ -166,29 +172,30 @@ def apply_fit_shade_px(fig, spec, terms):
for x, dashed in vlines:
out.add_vline(
x=x,
line={"color": C._OBS_COLOR, "width": 2, "dash": "dash" if dashed else "solid"},
line={"color": lc, "width": 2, "dash": "dash" if dashed else "solid"},
row=1,
col=col,
)
for xmin, xmax in rects:
out.add_vrect(
x0=_clip(xmin, lo, hi),
x1=_clip(xmax, lo, hi),
fillcolor=C._OBS_COLOR,
fillcolor=fc,
opacity=0.3,
line_width=0,
row=1,
col=col,
)
else: # confidence_interval
color = spec.color or C._SHADE_COLOR
lc = spec.color or C._SHADE_COLOR
fc = spec.fill or spec.color or C._SHADE_COLOR
for term, (lower, upper) in per.items():
col = term_to_col.get(term)
if col is None:
continue
out.add_vrect(
x0=lower, x1=upper, fillcolor=color, opacity=0.3, line_width=0, row=1, col=col
x0=lower, x1=upper, fillcolor=fc, opacity=0.3, line_width=0, row=1, col=col
)
out.add_vline(x=lower, line={"color": color, "width": 2}, row=1, col=col)
out.add_vline(x=upper, line={"color": color, "width": 2}, row=1, col=col)
out.add_vline(x=lower, line={"color": lc, "width": 2}, row=1, col=col)
out.add_vline(x=upper, line={"color": lc, "width": 2}, row=1, col=col)
return out
Loading