diff --git a/CHANGELOG.md b/CHANGELOG.md index 96afa73..35297a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ Full parity with the R `moderndive` and `infer` packages. +- **R argument parity (infer)**: `hypothesize(med=)`/`observe(med=)` add a median point null; `prop_test()` gains `z`, `correct` (Yates), `conf_int`, and `conf_level` (now matching R's `prop.test` — chi-square by default, with a Wilson-score CI for one proportion); `rep_slice_sample(prop=, weight_by=)` and `rep_sample_n(prob=)` add fractional and weighted sampling; `generate(variables=)` chooses which column to permute; `shade_p_value`/`shade_confidence_interval` gain `fill`; `visualize(dens_color=)` sets the theoretical-curve color. (`shade_p_value` now also honors `color`.) + - **Chi-square goodness-of-fit** (closes the last `infer` vignette gap): `specify(response=cat).hypothesize(null="point", p={level: prob, ...})` with `generate(type="draw")` and `calculate(stat="Chisq")` now runs a one-variable diff --git a/moderndive/infer/core.py b/moderndive/infer/core.py index d04a35b..0e9a457 100644 --- a/moderndive/infer/core.py +++ b/moderndive/infer/core.py @@ -82,17 +82,25 @@ def hypothesize( null: str, *, mu: float | None = None, + med: float | None = None, p: float | dict | None = None, sigma: float | None = None, ) -> Hypothesis: if null not in ("point", "independence", "paired independence"): raise ValueError("null must be 'point', 'independence', or 'paired independence'") - return Hypothesis(spec=self, null=null, mu=mu, p=p, sigma=sigma) + return Hypothesis(spec=self, null=null, mu=mu, med=med, p=p, sigma=sigma) def generate( - self, reps: int, type: str = "bootstrap", *, seed: int | None = None + self, + reps: int, + type: str = "bootstrap", + *, + variables: str | None = None, + seed: int | None = None, ) -> GeneratedReplicates: - return _generate(self, hypothesis=None, reps=reps, type=type, seed=seed) + return _generate( + self, hypothesis=None, reps=reps, type=type, variables=variables, seed=seed + ) def calculate( self, @@ -124,8 +132,8 @@ def assume(self, distribution: str, df: object | None = None): return _assume(distribution, df=df) # British-spelling alias (infer parity). - def hypothesise(self, null: str, *, mu=None, p=None, sigma=None): - return self.hypothesize(null, mu=mu, p=p, sigma=sigma) + def hypothesise(self, null: str, *, mu=None, med=None, p=None, sigma=None): + return self.hypothesize(null, mu=mu, med=med, p=p, sigma=sigma) def fit(self) -> FitResult: """Fit the observed regression (ordinary least squares) for the formula.""" @@ -144,15 +152,23 @@ class Hypothesis: spec: Specification null: str mu: float | None = None + med: float | None = None p: float | dict | None = None sigma: float | None = None def generate( - self, reps: int, type: str | None = None, *, seed: int | None = None + self, + reps: int, + type: str | None = None, + *, + variables: str | None = None, + seed: int | None = None, ) -> GeneratedReplicates: if type is None: type = "bootstrap" if self.null == "point" else "permute" - return _generate(self.spec, hypothesis=self, reps=reps, type=type, seed=seed) + return _generate( + self.spec, hypothesis=self, reps=reps, type=type, variables=variables, seed=seed + ) def calculate( self, @@ -182,6 +198,7 @@ class GeneratedReplicates: hyp_mu: float | None = None hyp_p: float | dict | None = None hyp_sigma: float | None = None + variables: str | None = None # --- single-variable / two-group statistics --------------------------- def calculate(self, stat, *, order: tuple[object, object] | None = None) -> Distribution: @@ -201,8 +218,12 @@ def calculate(self, stat, *, order: tuple[object, object] | None = None) -> Dist resp = resp_full * plan # randomly flip the sign of each difference expl = None elif self.type == "permute": - resp = resp_full - expl = None if expl_full is None else expl_full[plan] + if self.variables is not None and self.variables == spec.response: + resp = resp_full[plan] # permute the response instead of the explanatory + expl = expl_full + else: + resp = resp_full + expl = None if expl_full is None else expl_full[plan] else: # draw resp = plan # the simulated response array expl = None @@ -267,18 +288,22 @@ def _generate( hypothesis: Hypothesis | None, reps: int, type: str, + variables: str | None = None, seed: int | None, ) -> GeneratedReplicates: if type == "simulate": # infer accepts "simulate" as an alias for "draw" type = "draw" if type not in ("bootstrap", "permute", "draw"): raise ValueError("type must be 'bootstrap', 'permute', 'draw', or 'simulate'") + if variables is not None and variables not in (spec.response, spec.explanatory): + raise ValueError(f"variables={variables!r} must be the response or explanatory variable.") rng = _resample.make_rng(seed) n = spec.data.height null = None if hypothesis is None else hypothesis.null hyp_mu = None if hypothesis is None else hypothesis.mu hyp_p = None if hypothesis is None else hypothesis.p hyp_sigma = None if hypothesis is None else hypothesis.sigma + hyp_med = None if hypothesis is None else hypothesis.med shifted = None plans: list[np.ndarray] = [] @@ -313,6 +338,10 @@ def _generate( shifted = _resample.shift_for_point_null( spec._response_values, stat="mean", mu=hypothesis.mu, p=None ) + elif hypothesis is not None and hypothesis.null == "point" and hyp_med is not None: + shifted = _resample.shift_for_point_null( + spec._response_values, stat="median", mu=hyp_med, p=None + ) for _ in range(reps): plans.append(rng.integers(0, n, size=n)) @@ -325,6 +354,7 @@ def _generate( hyp_mu=hyp_mu, hyp_p=hyp_p, hyp_sigma=hyp_sigma, + variables=variables, ) @@ -494,6 +524,7 @@ def observe( order: tuple[object, object] | None = None, null: str | None = None, mu: float | None = None, + med: float | None = None, p: float | dict | None = None, sigma: float | None = None, ) -> ObservedStatistic: @@ -505,7 +536,9 @@ def observe( data, response=response, explanatory=explanatory, formula=formula, success=success ) if null is not None: - return spec.hypothesize(null=null, mu=mu, p=p, sigma=sigma).calculate(stat, order=order) + return spec.hypothesize(null=null, mu=mu, med=med, p=p, sigma=sigma).calculate( + stat, order=order + ) return spec.calculate(stat, order=order, mu=mu, p=p, sigma=sigma) diff --git a/moderndive/infer/viz/__init__.py b/moderndive/infer/viz/__init__.py index af509cf..bc3a12c 100644 --- a/moderndive/infer/viz/__init__.py +++ b/moderndive/infer/viz/__init__.py @@ -55,6 +55,7 @@ class ShadeSpec: lower: float | None = None upper: float | None = None color: str | None = None + fill: str | None = None per_term: tuple | None = None @@ -182,6 +183,7 @@ def visualize( *, engine: str = "plotly", method: str = "simulation", + dens_color: str | None = None, shade_pvalue=None, shade_ci=None, **kwargs, @@ -191,6 +193,7 @@ def visualize( ``method`` is ``"simulation"`` (histogram, default), ``"theoretical"`` (a normal-approximation density curve), or ``"both"`` (histogram in density units overlaid with the normal curve), mirroring R ``infer``'s ``visualize(method=)``. + ``dens_color`` sets the theoretical-curve color (for ``"theoretical"``/``"both"``). Pass ``shade_pvalue=``/``shade_ci=`` to shade in one call, or compose with ``+``. """ engine = C.resolve_engine(engine) @@ -198,11 +201,11 @@ def visualize( if engine == "plotnine": from . import _plotnine as P - fig = P.visualize_gg(distribution, bins, method) + fig = P.visualize_gg(distribution, bins, method, dens_color) else: from . import _plotly as PX - fig = PX.visualize_px(distribution, bins, method) + fig = PX.visualize_px(distribution, bins, method, dens_color) plot = InferPlot(fig, engine) if shade_pvalue is not None: @@ -288,7 +291,9 @@ def _per_term_ci(endpoints) -> dict | None: return None -def shade_p_value(obs_stat, direction: str, *, color: str | None = None) -> ShadeSpec: +def shade_p_value( + obs_stat, direction: str, *, color: str | None = None, fill: str | None = None +) -> ShadeSpec: """A p-value shading spec; add it to a ``visualize()`` plot with ``+``. ``direction`` ∈ {right/greater, left/less, two-sided}. For a faceted @@ -298,12 +303,20 @@ def shade_p_value(obs_stat, direction: str, *, color: str | None = None) -> Shad per = _per_term_obs(obs_stat) if per is not None: return ShadeSpec( - kind="p_value", direction=direction, color=color, per_term=tuple(sorted(per.items())) + kind="p_value", + direction=direction, + color=color, + fill=fill, + per_term=tuple(sorted(per.items())), ) - return ShadeSpec(kind="p_value", obs_stat=float(obs_stat), direction=direction, color=color) + return ShadeSpec( + kind="p_value", obs_stat=float(obs_stat), direction=direction, color=color, fill=fill + ) -def shade_confidence_interval(endpoints, color: str | None = None) -> ShadeSpec: +def shade_confidence_interval( + endpoints, color: str | None = None, fill: str | None = None +) -> ShadeSpec: """A confidence-interval shading spec; add it to a ``visualize()`` plot with ``+``. ``endpoints`` is a CI DataFrame (``lower_ci``/``upper_ci``) or a ``(lower, upper)`` @@ -313,7 +326,10 @@ def shade_confidence_interval(endpoints, color: str | None = None) -> ShadeSpec: per = _per_term_ci(endpoints) if per is not None: return ShadeSpec( - kind="confidence_interval", color=color, per_term=tuple(sorted(per.items())) + kind="confidence_interval", + color=color, + fill=fill, + per_term=tuple(sorted(per.items())), ) lower, upper = C.ci_endpoints(endpoints) - return ShadeSpec(kind="confidence_interval", lower=lower, upper=upper, color=color) + return ShadeSpec(kind="confidence_interval", lower=lower, upper=upper, color=color, fill=fill) diff --git a/moderndive/infer/viz/_plotly.py b/moderndive/infer/viz/_plotly.py index 8cca4bc..1d993b6 100644 --- a/moderndive/infer/viz/_plotly.py +++ b/moderndive/infer/viz/_plotly.py @@ -25,23 +25,24 @@ def _layout(fig, title: str, xlab: str, ylab: str): return fig -def density_curve_px(x, density, title: str, xlab: str = "statistic"): +def density_curve_px(x, density, title: str, xlab: str = "statistic", color: str | None = None): """A standalone theoretical density curve.""" go = _go() - fig = go.Figure(go.Scatter(x=x, y=density, mode="lines", line={"color": C._OBS_COLOR})) + fig = go.Figure(go.Scatter(x=x, y=density, mode="lines", line={"color": color or C._OBS_COLOR})) return _layout(fig, title, xlab, "density") -def visualize_px(distribution, bins: int, method: str): +def visualize_px(distribution, bins: int, method: str, dens_color: str | None = None): """Histogram of simulated statistics, optionally overlaid with a normal curve.""" go = _go() values = C.stat_values(distribution) xlab = C.stat_label(distribution.stat) title = C.dist_title(distribution.null) + curve_color = dens_color or C._OBS_COLOR if method == "theoretical": x, dens = C.normal_overlay(values) - return density_curve_px(x, dens, "Theoretical Distribution", xlab) + return density_curve_px(x, dens, "Theoretical Distribution", xlab, curve_color) histnorm = "probability density" if method == "both" else None fig = go.Figure( @@ -54,7 +55,7 @@ def visualize_px(distribution, bins: int, method: str): ) if method == "both": x, dens = C.normal_overlay(values) - fig.add_scatter(x=x, y=dens, mode="lines", line={"color": C._OBS_COLOR}) + fig.add_scatter(x=x, y=dens, mode="lines", line={"color": curve_color}) return _layout(fig, title, xlab, "density" if method == "both" else "count") @@ -112,24 +113,27 @@ def apply_shade_px(fig, spec): lo, hi = _data_range(out) if spec.kind == "p_value": + lc = spec.color or C._OBS_COLOR + fc = spec.fill or spec.color or C._OBS_COLOR vlines, rects = C.pvalue_regions(spec.obs_stat, spec.direction) for x, dashed in vlines: out.add_vline( - x=x, line={"color": C._OBS_COLOR, "width": 2, "dash": "dash" if dashed else "solid"} + x=x, line={"color": lc, "width": 2, "dash": "dash" if dashed else "solid"} ) for xmin, xmax in rects: out.add_vrect( x0=_clip(xmin, lo, hi), x1=_clip(xmax, lo, hi), - fillcolor=C._OBS_COLOR, + fillcolor=fc, opacity=0.3, line_width=0, ) else: # confidence_interval - color = spec.color or C._SHADE_COLOR - out.add_vrect(x0=spec.lower, x1=spec.upper, fillcolor=color, opacity=0.3, line_width=0) - out.add_vline(x=spec.lower, line={"color": color, "width": 2}) - out.add_vline(x=spec.upper, line={"color": color, "width": 2}) + lc = spec.color or C._SHADE_COLOR + fc = spec.fill or spec.color or C._SHADE_COLOR + out.add_vrect(x0=spec.lower, x1=spec.upper, fillcolor=fc, opacity=0.3, line_width=0) + out.add_vline(x=spec.lower, line={"color": lc, "width": 2}) + out.add_vline(x=spec.upper, line={"color": lc, "width": 2}) return out @@ -157,6 +161,8 @@ def apply_fit_shade_px(fig, spec, terms): per = dict(spec.per_term) if spec.kind == "p_value": + lc = spec.color or C._OBS_COLOR + fc = spec.fill or spec.color or C._OBS_COLOR for term, obs in per.items(): col = term_to_col.get(term) if col is None: @@ -166,7 +172,7 @@ def apply_fit_shade_px(fig, spec, terms): for x, dashed in vlines: out.add_vline( x=x, - line={"color": C._OBS_COLOR, "width": 2, "dash": "dash" if dashed else "solid"}, + line={"color": lc, "width": 2, "dash": "dash" if dashed else "solid"}, row=1, col=col, ) @@ -174,21 +180,22 @@ def apply_fit_shade_px(fig, spec, terms): out.add_vrect( x0=_clip(xmin, lo, hi), x1=_clip(xmax, lo, hi), - fillcolor=C._OBS_COLOR, + fillcolor=fc, opacity=0.3, line_width=0, row=1, col=col, ) else: # confidence_interval - color = spec.color or C._SHADE_COLOR + lc = spec.color or C._SHADE_COLOR + fc = spec.fill or spec.color or C._SHADE_COLOR for term, (lower, upper) in per.items(): col = term_to_col.get(term) if col is None: continue out.add_vrect( - x0=lower, x1=upper, fillcolor=color, opacity=0.3, line_width=0, row=1, col=col + x0=lower, x1=upper, fillcolor=fc, opacity=0.3, line_width=0, row=1, col=col ) - out.add_vline(x=lower, line={"color": color, "width": 2}, row=1, col=col) - out.add_vline(x=upper, line={"color": color, "width": 2}, row=1, col=col) + out.add_vline(x=lower, line={"color": lc, "width": 2}, row=1, col=col) + out.add_vline(x=upper, line={"color": lc, "width": 2}, row=1, col=col) return out diff --git a/moderndive/infer/viz/_plotnine.py b/moderndive/infer/viz/_plotnine.py index 3f3c268..9a4b7d0 100644 --- a/moderndive/infer/viz/_plotnine.py +++ b/moderndive/infer/viz/_plotnine.py @@ -24,26 +24,27 @@ def _full_height_rect(xmin: float, xmax: float, fill: str): return annotate("rect", xmin=xmin, xmax=xmax, ymin=-C._INF, ymax=C._INF, alpha=0.3, fill=fill) -def density_curve_gg(x, density, title: str, xlab: str = "statistic"): +def density_curve_gg(x, density, title: str, xlab: str = "statistic", color: str | None = None): """A standalone theoretical density curve.""" pdf = pd.DataFrame({"x": x, "density": density}) return ( ggplot(pdf, aes(x="x", y="density")) - + geom_line(color=C._OBS_COLOR, size=1.0) + + geom_line(color=color or C._OBS_COLOR, size=1.0) + labs(x=xlab, y="density", title=title) + theme_light() ) -def visualize_gg(distribution, bins: int, method: str): +def visualize_gg(distribution, bins: int, method: str, dens_color: str | None = None): """Histogram of simulated statistics, optionally overlaid with a normal curve.""" values = C.stat_values(distribution) xlab = C.stat_label(distribution.stat) title = C.dist_title(distribution.null) + curve_color = dens_color or C._OBS_COLOR if method == "theoretical": x, dens = C.normal_overlay(values) - return density_curve_gg(x, dens, "Theoretical Distribution", xlab) + return density_curve_gg(x, dens, "Theoretical Distribution", xlab, curve_color) pdf = pd.DataFrame({"stat": values}) if method == "both": @@ -56,7 +57,7 @@ def visualize_gg(distribution, bins: int, method: str): + geom_line( aes(x="stat", y="density"), data=pd.DataFrame({"stat": x, "density": dens}), - color=C._OBS_COLOR, + color=curve_color, size=1.0, ) + labs(x=xlab, y="density", title=title) @@ -86,22 +87,25 @@ def visualize_fit_gg(fit, bins: int): def shade_pvalue_layers(spec) -> list: """plotnine layers shading the p-value tail(s) and marking the observed stat.""" vlines, rects = C.pvalue_regions(spec.obs_stat, spec.direction) + lc = spec.color or C._OBS_COLOR + fc = spec.fill or spec.color or C._OBS_COLOR layers: list = [] for x, dashed in vlines: extra = {"linetype": "dashed"} if dashed else {} - layers.append(geom_vline(xintercept=x, color=C._OBS_COLOR, size=1.0, **extra)) + layers.append(geom_vline(xintercept=x, color=lc, size=1.0, **extra)) for xmin, xmax in rects: - layers.append(_full_height_rect(xmin, xmax, C._OBS_COLOR)) + layers.append(_full_height_rect(xmin, xmax, fc)) return layers def shade_ci_layers(spec) -> list: """plotnine layers shading the confidence interval between its endpoints.""" - color = spec.color or C._SHADE_COLOR + lc = spec.color or C._SHADE_COLOR + fc = spec.fill or spec.color or C._SHADE_COLOR return [ - _full_height_rect(spec.lower, spec.upper, color), - geom_vline(xintercept=spec.lower, color=color, size=1.0), - geom_vline(xintercept=spec.upper, color=color, size=1.0), + _full_height_rect(spec.lower, spec.upper, fc), + geom_vline(xintercept=spec.lower, color=lc, size=1.0), + geom_vline(xintercept=spec.upper, color=lc, size=1.0), ] @@ -149,14 +153,17 @@ def apply_fit_shade_gg(gg, spec, terms): (dashed if is_dashed else solid).append({"term": term, "x": x}) for xmin, xmax in term_rects: rects.append({"term": term, "xmin": xmin, "xmax": xmax}) + lc = spec.color or C._OBS_COLOR + fc = spec.fill or spec.color or C._OBS_COLOR if solid: - layers.append(_facet_vlines(solid, C._OBS_COLOR, dashed=False)) + layers.append(_facet_vlines(solid, lc, dashed=False)) if dashed: - layers.append(_facet_vlines(dashed, C._OBS_COLOR, dashed=True)) + layers.append(_facet_vlines(dashed, lc, dashed=True)) if rects: - layers.append(_facet_rect(rects, C._OBS_COLOR)) + layers.append(_facet_rect(rects, fc)) else: - color = spec.color or C._SHADE_COLOR + lc = spec.color or C._SHADE_COLOR + fc = spec.fill or spec.color or C._SHADE_COLOR rects, edges = [], [] for term, (lower, upper) in per.items(): if term not in terms: @@ -165,7 +172,7 @@ def apply_fit_shade_gg(gg, spec, terms): edges.append({"term": term, "x": lower}) edges.append({"term": term, "x": upper}) if rects: - layers.append(_facet_rect(rects, color)) + layers.append(_facet_rect(rects, fc)) if edges: - layers.append(_facet_vlines(edges, color, dashed=False)) + layers.append(_facet_vlines(edges, lc, dashed=False)) return gg + layers diff --git a/moderndive/infer/wrappers.py b/moderndive/infer/wrappers.py index 3715484..a1de893 100644 --- a/moderndive/infer/wrappers.py +++ b/moderndive/infer/wrappers.py @@ -76,6 +76,26 @@ def t_stat(data: pl.DataFrame, **kwargs) -> float: return float(t_test(data, **kwargs)["statistic"][0]) +def _wilson_cc(x: int, n: int, conf_level: float, correct: bool) -> tuple[float, float]: + """Wilson score interval for one proportion (R prop.test's CI; cc = Yates).""" + from scipy import stats + + z = float(stats.norm.ppf((1 + conf_level) / 2)) + phat = x / n + if correct: + lo = ( + 2 * x + z**2 - 1 - z * np.sqrt(z**2 - 2 - 1 / n + 4 * phat * (n * (1 - phat) + 1)) + ) / (2 * (n + z**2)) + hi = ( + 2 * x + z**2 + 1 + z * np.sqrt(z**2 + 2 - 1 / n + 4 * phat * (n * (1 - phat) - 1)) + ) / (2 * (n + z**2)) + return max(0.0, lo), min(1.0, hi) + denom = 1 + z**2 / n + center = (phat + z**2 / (2 * n)) / denom + half = z * np.sqrt(phat * (1 - phat) / n + z**2 / (4 * n**2)) / denom + return center - half, center + half + + def prop_test( data: pl.DataFrame, *, @@ -86,19 +106,34 @@ def prop_test( order: tuple[object, object] | None = None, p: float | None = None, alternative: str = "two-sided", + z: bool = False, + correct: bool = True, + conf_int: bool = True, + conf_level: float = 0.95, ) -> pl.DataFrame: - """One- or two-proportion z-test (normal approximation), tidy output.""" + """Tidy one- or two-proportion test, mirroring R ``infer::prop_test``. + + By default reports the **chi-square** statistic (like R's ``prop.test``) with a + ``chisq_df`` column; pass ``z=True`` for the signed **z** statistic instead. + ``correct`` applies Yates' continuity correction. With ``conf_int=True`` + (default) the output includes a ``conf_level`` confidence interval — on the + proportion (one-sample) or on the difference in proportions (two-sample). + """ from scipy import stats resp, expl = _resolve(formula, response, explanatory) + if expl is None: col = data[resp].drop_nulls() n = col.len() x = int((col == success).sum()) - p0 = 0.5 if p is None else p phat = x / n - se = np.sqrt(p0 * (1 - p0) / n) - z = (phat - p0) / se + p0 = 0.5 if p is None else p + diff = phat - p0 + cc = min(0.5 / n, abs(diff)) if correct else 0.0 + zstat = np.sign(diff) * (abs(diff) - cc) / np.sqrt(p0 * (1 - p0) / n) + estimate = phat + ci_bounds = _wilson_cc(x, n, conf_level, correct) # R uses Wilson for one proportion else: if order is None: raise ValueError("two-proportion prop_test requires order=(group1, group2)") @@ -108,16 +143,35 @@ def prop_test( b = sub.filter(pl.col(expl) == g2)[resp] xa, xb = int((a == success).sum()), int((b == success).sum()) na, nb = a.len(), b.len() + pa, pb = xa / na, xb / nb + diff = pa - pb ppool = (xa + xb) / (na + nb) - se = np.sqrt(ppool * (1 - ppool) * (1 / na + 1 / nb)) - z = (xa / na - xb / nb) / se + cc = min(0.5 * (1 / na + 1 / nb), abs(diff)) if correct else 0.0 + zstat = np.sign(diff) * (abs(diff) - cc) / np.sqrt(ppool * (1 - ppool) * (1 / na + 1 / nb)) + se_est = np.sqrt(pa * (1 - pa) / na + pb * (1 - pb) / nb) + estimate = diff + crit = float(stats.norm.ppf(1 - (1 - conf_level) / 2)) + half = crit * se_est + cc # R's prop.test widens the 2-sample diff CI by the correction + ci_bounds = (diff - half, diff + half) + if alternative in _GREATER: - pval = float(stats.norm.sf(z)) + pval = float(stats.norm.sf(zstat)) elif alternative in _LESS: - pval = float(stats.norm.cdf(z)) + pval = float(stats.norm.cdf(zstat)) else: - pval = float(2 * stats.norm.sf(abs(z))) - return pl.DataFrame({"statistic": [float(z)], "p_value": [pval], "alternative": [alternative]}) + pval = float(2 * stats.norm.sf(abs(zstat))) + + statistic = float(zstat) if z else float(zstat**2) + out = {"statistic": [statistic]} + if not z: + out["chisq_df"] = [1] + out["p_value"] = [pval] + out["estimate"] = [float(estimate)] + out["alternative"] = [alternative] + if conf_int: + out["lower_ci"] = [float(ci_bounds[0])] + out["upper_ci"] = [float(ci_bounds[1])] + return pl.DataFrame(out) def chisq_test( diff --git a/moderndive/sampling.py b/moderndive/sampling.py index 098cc9e..8b4c9e6 100644 --- a/moderndive/sampling.py +++ b/moderndive/sampling.py @@ -1,9 +1,9 @@ """Repeated sampling helpers for the sampling activities (Chapter 7). Mirrors the R ``moderndive`` functions ``rep_slice_sample()`` / -``rep_sample_n()``: draw ``reps`` samples of size ``n`` from a data frame and -stack them into one long data frame with a ``replicate`` column (1..reps), so a -grouped summary computes one statistic per virtual sample. +``rep_sample_n()``: draw ``reps`` samples from a data frame and stack them into +one long data frame with a ``replicate`` column (1..reps), so a grouped summary +computes one statistic per virtual sample. """ from __future__ import annotations @@ -11,47 +11,82 @@ import numpy as np import polars as pl +from ._messaging import helpful_error + __all__ = ["rep_slice_sample", "rep_sample_n"] +def _weights(data: pl.DataFrame, weight_by) -> np.ndarray | None: + """Normalize ``weight_by`` (a column name or a sequence) into probabilities.""" + if weight_by is None: + return None + w = data[weight_by].to_numpy() if isinstance(weight_by, str) else np.asarray(weight_by, float) + total = w.sum() + if total <= 0: + raise ValueError(helpful_error("weight_by must contain positive weights that sum to > 0.")) + return w / total + + def rep_slice_sample( data: pl.DataFrame, - n: int, + n: int | None = None, + *, + prop: float | None = None, reps: int = 1, replace: bool = False, + weight_by=None, seed: int | None = None, ) -> pl.DataFrame: - """Take ``reps`` samples of size ``n`` from ``data``. + """Take ``reps`` samples from ``data``. - Returns a polars DataFrame with a leading ``replicate`` column identifying - which sample each row belongs to. Set ``replace=True`` for sampling with - replacement (e.g. bootstrap-style). Pass ``seed`` for reproducibility. + Give the sample size as either ``n`` (a count) or ``prop`` (a fraction of the + rows, e.g. ``prop=0.5``). Returns a polars DataFrame with a leading + ``replicate`` column identifying which sample each row belongs to. Set + ``replace=True`` for sampling with replacement (bootstrap-style). ``weight_by`` + gives unequal selection probabilities — a column name or a sequence of + weights. Pass ``seed`` for reproducibility. """ - rng = np.random.default_rng(seed) + if (n is None) == (prop is None): + raise ValueError( + helpful_error( + "Specify exactly one of n= (a count) or prop= (a fraction).", + "e.g. rep_slice_sample(df, n=50) or rep_slice_sample(df, prop=0.5).", + ) + ) n_rows = data.height - if not replace and n > n_rows: - raise ValueError(f"cannot take a sample of size {n} without replacement from {n_rows} rows") + size = n if n is not None else int(round(prop * n_rows)) + if not replace and size > n_rows: + raise ValueError( + f"cannot take a sample of size {size} without replacement from {n_rows} rows" + ) + probs = _weights(data, weight_by) + rng = np.random.default_rng(seed) samples = [] for replicate in range(1, reps + 1): - idx = rng.choice(n_rows, size=n, replace=replace) + idx = rng.choice(n_rows, size=size, replace=replace, p=probs) sample = data[idx.tolist()].with_columns( pl.lit(replicate, dtype=pl.Int64).alias("replicate") ) samples.append(sample) combined = pl.concat(samples) - # Put `replicate` first. return combined.select(["replicate", *data.columns]) -# The older moderndive name is a thin alias with the same behavior. def rep_sample_n( data: pl.DataFrame, n: int, + *, reps: int = 1, replace: bool = False, + prob=None, seed: int | None = None, ) -> pl.DataFrame: - """Alias for :func:`rep_slice_sample` (older moderndive name).""" - return rep_slice_sample(data, n=n, reps=reps, replace=replace, seed=seed) + """Take ``reps`` samples of size ``n`` (older moderndive name). + + Like :func:`rep_slice_sample`, but the sample size is always the count ``n`` + and unequal selection weights are passed as ``prob`` (a column name or a + sequence), matching the R ``rep_sample_n`` signature. + """ + return rep_slice_sample(data, n=n, reps=reps, replace=replace, weight_by=prob, seed=seed) diff --git a/tests/test_arg_parity_infer.py b/tests/test_arg_parity_infer.py new file mode 100644 index 0000000..af774cd --- /dev/null +++ b/tests/test_arg_parity_infer.py @@ -0,0 +1,233 @@ +"""Tests for the infer R argument-parity additions: + +- hypothesize/observe(med=) — median point null +- prop_test(z=, correct=, conf_int=, conf_level=) — validated vs R's prop.test +- rep_slice_sample(prop=, weight_by=) / rep_sample_n(prob=) +- generate(variables=) +- shade_p_value/shade_confidence_interval(fill=), visualize(dens_color=) +""" + +from __future__ import annotations + +import matplotlib + +matplotlib.use("Agg") + +import numpy as np +import plotly.graph_objects as go +import polars as pl +import pytest +from plotnine import ggplot + +import moderndive as md +from moderndive import ( + get_p_value, + observe, + prop_test, + rep_sample_n, + rep_slice_sample, + shade_confidence_interval, + shade_p_value, + visualize, +) + +# ============================ median point null ========================= + + +def test_median_point_null_centers_and_pvalue(): + rng = np.random.default_rng(0) + df = pl.DataFrame({"x": rng.normal(10.0, 3.0, 300)}) + null = ( + df.specify(response="x") + .hypothesize(null="point", med=8.0) + .generate(reps=500, type="bootstrap", seed=1) + .calculate(stat="median") + ) + assert float(null.stats.mean()) == pytest.approx(8.0, abs=0.5) + obs = observe(df, response="x", stat="median") + pv = float(get_p_value(null, obs_stat=obs, direction="two-sided")["p_value"][0]) + assert 0.0 <= pv <= 1.0 + + +def test_observe_med_via_hypothesise_alias(): + df = pl.DataFrame({"x": [1.0, 2, 3, 4, 5]}) + h = df.specify(response="x").hypothesise(null="point", med=3.0) + assert h.med == 3.0 + + +# ============================ prop_test (vs R prop.test) ================ + + +def _two_group(): + rows = ( + [("yes", "seed")] * 10 + + [("no", "seed")] * 24 + + [("yes", "control")] * 4 + + [("no", "control")] * 12 + ) + return pl.DataFrame({"yawn": [r[0] for r in rows], "group": [r[1] for r in rows]}) + + +def test_prop_test_two_sample_matches_r(): + out = prop_test(_two_group(), formula="yawn ~ group", success="yes", order=("seed", "control")) + # R prop.test(c(10,4), c(34,16)): X-squared 0, p 1, CI [-0.26168, 0.34991] + assert out["statistic"][0] == pytest.approx(0.0, abs=1e-9) + assert out["chisq_df"][0] == 1 + assert out["p_value"][0] == pytest.approx(1.0) + assert out["lower_ci"][0] == pytest.approx(-0.26168, abs=1e-4) + assert out["upper_ci"][0] == pytest.approx(0.34991, abs=1e-4) + + +def test_prop_test_two_sample_no_correction_and_z(): + g = _two_group() + nc = prop_test( + g, formula="yawn ~ group", success="yes", order=("seed", "control"), correct=False + ) + assert nc["statistic"][0] == pytest.approx(0.10504, abs=1e-4) # R X-squared, correct=FALSE + z = prop_test(g, formula="yawn ~ group", success="yes", order=("seed", "control"), z=True) + assert "chisq_df" not in z.columns + assert {"statistic", "p_value", "estimate", "alternative", "lower_ci", "upper_ci"} <= set( + z.columns + ) + + +def test_prop_test_one_sample_matches_r(): + df = pl.DataFrame({"x": ["yes"] * 14 + ["no"] * 36}) + out = prop_test(df, response="x", success="yes", p=0.3) + # R prop.test(14, 50, p=0.3): X-squared 0.02381, p 0.87737, Wilson CI [0.1667, 0.4271] + assert out["statistic"][0] == pytest.approx(0.02381, abs=1e-4) + assert out["p_value"][0] == pytest.approx(0.87737, abs=1e-4) + assert out["lower_ci"][0] == pytest.approx(0.1667, abs=1e-3) + assert out["upper_ci"][0] == pytest.approx(0.4271, abs=1e-3) + # one-sided alternative matches R's "less" p-value + less = prop_test(df, response="x", success="yes", p=0.3, alternative="less") + assert less["p_value"][0] == pytest.approx(0.43869, abs=1e-4) + + +def test_prop_test_one_sample_no_correction_ci(): + df = pl.DataFrame({"x": ["yes"] * 14 + ["no"] * 36}) + out = prop_test(df, response="x", success="yes", p=0.3, correct=False) + # uncorrected Wilson score interval for x=14, n=50 + assert out["lower_ci"][0] == pytest.approx(0.1747, abs=1e-3) + assert out["upper_ci"][0] == pytest.approx(0.4167, abs=1e-3) + + +def test_prop_test_conf_int_false_drops_ci(): + df = pl.DataFrame({"x": ["yes"] * 14 + ["no"] * 36}) + out = prop_test(df, response="x", success="yes", p=0.3, conf_int=False) + assert "lower_ci" not in out.columns and "upper_ci" not in out.columns + + +def test_prop_test_two_sample_needs_order(): + with pytest.raises(ValueError, match="order"): + prop_test(_two_group(), formula="yawn ~ group", success="yes") + + +# ============================ rep sampling ============================== + + +def test_rep_slice_sample_prop(): + bowl = md.load_bowl() + out = rep_slice_sample(bowl, prop=0.1, reps=3, seed=1) + assert out["replicate"].n_unique() == 3 + assert out.filter(pl.col("replicate") == 1).height == round(0.1 * bowl.height) + + +def test_rep_slice_sample_weight_by_column_and_sequence(): + df = pl.DataFrame({"x": list(range(10)), "w": [0.0] * 9 + [1.0]}) + # only the last row has weight → every draw is x == 9 + by_col = rep_slice_sample(df, n=5, replace=True, weight_by="w", seed=1) + assert set(by_col["x"].to_list()) == {9} + by_seq = rep_slice_sample(df, n=5, replace=True, weight_by=[0.0] * 9 + [1.0], seed=1) + assert set(by_seq["x"].to_list()) == {9} + + +def test_rep_sample_n_prob(): + df = pl.DataFrame({"x": list(range(10))}) + out = rep_sample_n(df, n=5, replace=True, prob=[0.0] * 9 + [1.0], seed=1) + assert set(out["x"].to_list()) == {9} + + +def test_rep_slice_sample_n_xor_prop(): + bowl = md.load_bowl() + with pytest.raises(ValueError, match="exactly one"): + rep_slice_sample(bowl, n=5, prop=0.1) + with pytest.raises(ValueError, match="exactly one"): + rep_slice_sample(bowl) + + +def test_rep_slice_sample_bad_weights(): + df = pl.DataFrame({"x": [1, 2, 3], "w": [0.0, 0.0, 0.0]}) + with pytest.raises(ValueError, match="positive weights"): + rep_slice_sample(df, n=2, replace=True, weight_by="w") + + +def test_rep_slice_sample_without_replacement_too_big(): + df = pl.DataFrame({"x": [1, 2, 3]}) + with pytest.raises(ValueError, match="without replacement"): + rep_slice_sample(df, n=5) + + +# ============================ generate(variables=) ====================== + + +def test_generate_variables_permutes_chosen_column(): + gss = md.load_gss() + # permuting the response gives a valid null distribution for the diff in means + null = ( + gss.specify(formula="age ~ college") + .hypothesize(null="independence") + .generate(reps=200, type="permute", variables="age", seed=1) + .calculate(stat="diff in means", order=("degree", "no degree")) + ) + assert null.data.height == 200 + assert abs(float(null.stats.mean())) < 1.0 # centered near 0 under the null + + +def test_generate_variables_must_be_a_model_variable(): + gss = md.load_gss() + with pytest.raises(ValueError, match="must be the response or explanatory"): + gss.specify(formula="age ~ college").hypothesize(null="independence").generate( + reps=5, type="permute", variables="nope" + ) + + +# ============================ shade fill + dens_color =================== + + +@pytest.mark.parametrize("engine", ["plotly", "plotnine"]) +def test_shade_fill_both_engines(engine): + boot = ( + md.load_age_at_marriage() + .specify(response="age") + .generate(reps=200, type="bootstrap", seed=1) + .calculate(stat="mean") + ) + from moderndive import get_confidence_interval + + ci = get_confidence_interval(boot, type="percentile") + p_ci = visualize(boot, engine=engine) + shade_confidence_interval( + ci, color="navy", fill="lightblue" + ) + p_pv = visualize(boot, engine=engine) + shade_p_value( + obs_stat=float(boot.data["stat"].mean()), + direction="right", + color="darkred", + fill="mistyrose", + ) + cls = go.Figure if engine == "plotly" else ggplot + assert isinstance(p_ci.figure, cls) and isinstance(p_pv.figure, cls) + + +@pytest.mark.parametrize("engine", ["plotly", "plotnine"]) +@pytest.mark.parametrize("method", ["theoretical", "both"]) +def test_dens_color_both_engines(engine, method): + boot = ( + md.load_age_at_marriage() + .specify(response="age") + .generate(reps=200, type="bootstrap", seed=1) + .calculate(stat="mean") + ) + p = visualize(boot, engine=engine, method=method, dens_color="green") + cls = go.Figure if engine == "plotly" else ggplot + assert isinstance(p.figure, cls) diff --git a/tests/test_infer_parity.py b/tests/test_infer_parity.py index 17e1a87..5c75e64 100644 --- a/tests/test_infer_parity.py +++ b/tests/test_infer_parity.py @@ -41,10 +41,17 @@ def test_ratio_and_odds_props_manual(): def test_chisq_equals_prop_test_z_squared(): yawn = _yawn() chi = float(specify(yawn, formula="yawn ~ group").calculate(stat="Chisq")) + # prop_test now defaults to the chi-square statistic (R parity); ask for the + # z explicitly. Without continuity correction, chi-square == z**2. z = float( - prop_test(yawn, formula="yawn ~ group", success="yes", order=("seed", "control"))[ - "statistic" - ][0] + prop_test( + yawn, + formula="yawn ~ group", + success="yes", + order=("seed", "control"), + z=True, + correct=False, + )["statistic"][0] ) assert chi == pytest.approx(z**2, rel=1e-6)