# Source code for qdiv.plot.ordination_plots

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse, Patch
from matplotlib.textpath import TextPath
from matplotlib.font_manager import FontProperties
from typing import Dict, List, Optional, Tuple, Union, Any
from ..stats import pcoa_lingoes
from ..utils import get_colors_markers, get_df

# ------------------------------
# Helpers for ordination plot
# ------------------------------
def _extract_ordination_payload(ordination_results):
    """
    Normalize ordination input into a common payload:
      - coords_df      : DataFrame of site scores (rows=samples, cols=axes)
      - axis_names     : list[str] of axis names (e.g., ['PCo1','PCo2'] or ['dbRDA1','dbRDA2'])
      - pct_explained  : pd.Series (% per axis, aligned to axis_names)
      - eigenvalues    : pd.Series (aligned to axis_names)
      - biplot_df      : optional DataFrame (rows=predictors, cols=axes)
      - kind           : 'pcoa' or 'dbrda'

    Accepts:
      - a distance DataFrame (square), or
      - dict from pcoa_lingoes, or
      - dict from dbrda
    """
    # Case A: a distance matrix → compute PCoA using provided pcoa_fn
    if isinstance(ordination_results, pd.DataFrame) and ordination_results.shape[0] == ordination_results.shape[1]:
        res = pcoa_lingoes(ordination_results)
        coords_df = res['site_scores']
        ev = res['eigenvalues']
        pct = res['pct_explained']
        axis_names = list(coords_df.columns)
        eigenvalues = pd.Series(np.array(ev), index=axis_names)
        pct_explained = pd.Series(np.array(pct), index=axis_names)
        return {
            'coords_df': coords_df,
            'axis_names': axis_names,
            'pct_explained': pct_explained,
            'eigenvalues': eigenvalues,
            'biplot_df': None,
            'kind': 'pcoa'}

    # Case B: dict from pcoa_lingoes or dbrda
    if isinstance(ordination_results, dict):
        coords_df = ordination_results.get('site_scores', None)
        if coords_df is None or not isinstance(coords_df, pd.DataFrame):
            raise ValueError("Could not find site scores in ordination dict.")

        biplot_df = ordination_results.get('biplot_scores', None)
        kind = 'dbrda' if biplot_df is not None else 'pcoa'
        axis_names = list(coords_df.columns)

        # Eigenvalues can be Series (PCoA) or ndarray (dbRDA)
        ev = ordination_results.get('eigenvalues', None)
        if ev is None:
            eigenvalues = pd.Series(index=axis_names, dtype=float)
        else:
            eigenvalues = pd.Series(np.array(ev).ravel(), index=axis_names[:len(np.array(ev).ravel())])

        # Explained % under various keys
        if 'pct_explained' in ordination_results:
            pct = ordination_results['pct_explained']  # already in %
        else:
            pct = pd.Series(ordination_results['explained_ratio'] * 100, index=axis_names).round(2)

        return {
            'coords_df': coords_df,
            'axis_names': axis_names,
            'pct_explained': pct,
            'eigenvalues': eigenvalues,
            'biplot_df': biplot_df,
            'kind': kind
        }
    raise TypeError("ordination must be a square distance DataFrame or a dict returned by pcoa_lingoes/dbrda.")


# ------------------------------
# Helpers: arrows & scaling
# ------------------------------
def _compute_pcoa_biplot(coords_2d: pd.DataFrame, meta: pd.DataFrame, variables: list, eigen_x: float, eigen_y: float):
    """
    Compute PCoA biplot arrows for two axes from numeric metadata columns.
    """
    # Standardize U (site scores)
    U_std = coords_2d.copy()
    for j in range(2):
        col = U_std.columns[j]
        std = U_std[col].std()
        U_std[col] = (U_std[col] - U_std[col].mean()) / (std if std and std > 0 else 1.0)

    # Standardize Y from meta
    Y = pd.DataFrame(index=coords_2d.index)
    for mh in variables:
        vals = pd.to_numeric(meta[mh], errors='coerce')
        std = vals.std()
        Y[mh] = (vals - vals.mean()) / (std if std and std > 0 else 1.0)
    Y_cent = Y.transpose()

    # Project
    Spc = (1 / (len(coords_2d.index) - 1)) * np.matmul(Y_cent, U_std.to_numpy())
    biglambda = np.array([[eigen_x ** -0.5 if eigen_x > 0 else 1.0, 0.0],
                          [0.0, eigen_y ** -0.5 if eigen_y > 0 else 1.0]])
    Uproj_arr = (len(coords_2d.index) - 1) ** 0.5 * np.matmul(Spc, biglambda)
    return pd.DataFrame(Uproj_arr, index=Y.columns, columns=coords_2d.columns.tolist())

def _scale_arrows_to_limits(Uproj: pd.DataFrame, xlims, ylims, margin=0.9):
    """
    Scale arrow coordinates to fit inside axis limits.
    """
    if Uproj is None or Uproj.empty:
        return Uproj
    xn, yn = Uproj.columns.tolist()
    max_abs_x = max(1e-12, np.max(np.abs(Uproj[xn])))
    max_abs_y = max(1e-12, np.max(np.abs(Uproj[yn])))
    scale_x = margin * (xlims[1] - xlims[0]) / (2 * max_abs_x)
    scale_y = margin * (ylims[1] - ylims[0]) / (2 * max_abs_y)
    return Uproj * min(scale_x, scale_y)

# ------------------------------
# Helpers: ellipses
# ------------------------------
def _covariance_ellipse_params(x: np.ndarray, y: np.ndarray, n_std: float = 2.0):
    """
    Compute covariance ellipse parameters (center, width, height, angle in degrees).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    mean_x, mean_y = np.mean(x), np.mean(y)
    cov = np.cov(x, y)
    vals, vecs = np.linalg.eigh(cov)
    order = np.argsort(vals)[::-1]
    vals = vals[order]
    vecs = vecs[:, order]
    width = 2 * n_std * np.sqrt(max(vals[0], 1e-12))
    height = 2 * n_std * np.sqrt(max(vals[1], 1e-12))
    angle = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))
    return mean_x, mean_y, width, height, angle

def _draw_ellipses(ax, coords: pd.DataFrame, meta: pd.DataFrame, group_col: str,
                   n_std: float = 2.0, edge_color='grey', lw=1.0,
                   label_centers=False, connect_by: str = None, colors=None):
    """
    Draw confidence ellipses for each category in `group_col`.
    Optionally connect ellipse centers ordered by `connect_by`.

    Handles:
      - Single-group case (skips connecting centers)
      - Non-numeric connect_by (skips connection)
      - Insufficient points for ellipse (draw centroid only)
    """
    xn, yn = coords.columns.tolist()
    cats = pd.unique(meta[group_col])
    centers = []

    # Draw ellipses or centroids
    for i, cat in enumerate(cats):
        mask = meta[group_col] == cat
        xs = coords.loc[mask, xn].values
        ys = coords.loc[mask, yn].values
        color = colors[i % len(colors)] if colors is not None else edge_color

        if len(xs) >= 3:  # enough points for ellipse
            cx, cy, w, h, ang = _covariance_ellipse_params(xs, ys, n_std=n_std)
            e = Ellipse((cx, cy), width=w, height=h, angle=ang,
                        facecolor='none', edgecolor=color, lw=lw)
            ax.add_patch(e)
            if label_centers:
                ax.annotate(str(cat), (cx, cy))
            centers.append((cat, cx, cy))
        else:
            # fallback: just mark centroid
            cx, cy = np.mean(xs), np.mean(ys)
            if label_centers:
                ax.annotate(str(cat), (cx, cy))
            centers.append((cat, cx, cy))

    # Connect centers if requested and valid
    if connect_by is not None and len(centers) > 1:
        try:
            # Compute mean of connect_by per group
            order_vals = meta.groupby(group_col)[connect_by].mean()
            if isinstance(order_vals, pd.Series):
                order_vals = order_vals.sort_values()
            else:
                # Single group fallback
                order_vals = pd.Series([order_vals], index=[meta[group_col].unique()[0]])
            df_centers = pd.DataFrame(centers, columns=[group_col, 'x', 'y']).set_index(group_col)
            ordered = df_centers.loc[order_vals.index]
            ax.plot(ordered['x'].values, ordered['y'].values, color='black', lw=lw)
        except Exception as e:
            # Gracefully skip if connect_by is invalid
            print(f"Skipping ellipse connection due to error: {e}")

# ------------------------------
# Helpers: points & legends
# ------------------------------
def _estimate_text_width_in(texts, fontsize=12, fontproperties=None):
    """
    Estimate max text width in inches using TextPath (points -> inches).
    """
    fp = fontproperties or FontProperties(size=fontsize)
    max_pt = 0.0
    for t in texts:
        s = "" if t is None else str(t)
        # Skip truly empty labels: they cause TextPath to return list-based paths.
        if s.strip() == "":
            continue
        try:
            tp = TextPath((0, 0), s, prop=fp, usetex=False)
            bbox = tp.get_extents()
            max_pt = max(max_pt, bbox.width)  # width in points
        except Exception:
            # Fallback heuristic: ~0.6 * fontsize (pt) per character
            max_pt = max(max_pt, 0.6 * fontsize * max(1, len(s)))
    # 1 pt = 1/72 inch
    return max_pt / 72.0

def _auto_legend_fraction_fast(
    meta,
    *,
    color_by,
    shape_by,
    figure_width_in,
    fontsize=12,
    marker_em=1.4,
    pad_em=1.0,
    min_fraction=0.16,
    max_fraction=0.48,
):
    """
    Heuristic legend fraction using text extents only.

    Estimates how wide the color and shape legends need to be (widest label
    or title, plus marker and padding allowances expressed in "em" units of
    the font size) and returns that width as a fraction of the figure width,
    clipped to ``[min_fraction, max_fraction]``.

    Parameters
    ----------
    meta : DataFrame
        Metadata holding the `color_by` / `shape_by` columns.
    color_by, shape_by : str or None
        Grouping columns; ``None`` means a single "all" color entry and no
        shape legend, respectively.
    figure_width_in : float
        Figure width in inches.
    fontsize : float
        Legend font size in points.
    marker_em, pad_em : float
        Extra width reserved for the legend marker and for padding, in ems.
    min_fraction, max_fraction : float
        Bounds of the returned fraction.

    Returns
    -------
    float
    """
    if color_by is None:
        color_labels = ["all"]
    else:
        color_labels = list(pd.unique(meta[color_by]))

    if shape_by is not None:
        # isinstance check replaces pd.api.types.is_categorical_dtype,
        # which is deprecated since pandas 2.1.
        if isinstance(meta[shape_by].dtype, pd.CategoricalDtype):
            shape_labels = list(meta[shape_by].cat.categories)
        else:
            shape_labels = list(pd.unique(meta[shape_by]))
    else:
        shape_labels = []

    # Convert "em" (roughly font size) to inches: fontsize pt -> in
    em_in = fontsize / 72.0

    color_title = color_by if color_by and color_by != "_all_" else ""
    shape_title = shape_by if shape_by else ""

    color_text_in = _estimate_text_width_in([color_title] + color_labels, fontsize=fontsize)
    shape_text_in = _estimate_text_width_in([shape_title] + shape_labels, fontsize=fontsize)

    # add room for markers/patches and padding
    decoration_in = (marker_em + pad_em) * em_in
    color_w_in = color_text_in + decoration_in
    shape_w_in = shape_text_in + decoration_in

    extra_padding_in = 3 * fontsize / 72.0
    needed_in = max(color_w_in, shape_w_in) + extra_padding_in
    # Guard against a zero/negative figure width.
    frac = needed_in / max(figure_width_in, 1e-6)
    return float(np.clip(frac, min_fraction, max_fraction))

def _default_colors(n):
    """
    Return whatever `get_colors_markers(get_type="colors")` yields.

    `n` is accepted but not used in this function.
    """
    return get_colors_markers(get_type="colors")

def _default_markers(n):
    """
    Return whatever `get_colors_markers(get_type="markers")` yields.

    `n` is accepted but not used in this function.
    """
    return get_colors_markers(get_type="markers")

def _draw_points(
    ax, coords, meta,
    *,
    color_by=None, shape_by=None,
    colors=None, markers=None, markersize=50, lw=1.0,
    connect=None, legend=True,
    legend_pos_colors=None, legend_pos_shapes=None,
    legend_titles=(None, None),
    markerscale=1.1,
    fontsize=12,
    ax_leg=None,
    color_legend_style="patch",
    color_legend_marker="o",
    color_legend_title=None,
    shape_legend_title=None,
):
    """
    Scatter plot of points grouped by color_by and shape_by,
    with optional connection lines and separate legend axis.
    """
    xn, yn = coords.columns.tolist()

    # ---- Prepare grouping (avoid mutating input meta) ----
    if color_by is None:
        meta = meta.copy()
        meta["_all_"] = "all"
        color_by = "_all_"

    cats1 = list(pd.unique(meta[color_by]))
    colors = colors or _default_colors(len(cats1))

    # Global shape levels (respect categorical order if present)
    shape_levels: List[str] = []
    if shape_by is not None:
        if pd.api.types.is_categorical_dtype(meta[shape_by]):
            shape_levels = list(meta[shape_by].cat.categories)
        else:
            shape_levels = list(pd.unique(meta[shape_by]))

    # Markers must at least cover the number of distinct shapes
    markers = markers or _default_markers(max(1, len(shape_levels)))

    # Build a global mapping: shape category -> marker
    shape_to_marker: Dict[Any, str] = {}
    if shape_by is not None:
        for idx, cat in enumerate(shape_levels):
            shape_to_marker[cat] = markers[idx % len(markers)]

    # ---- Legend containers ----
    linesColor = [[], []]  # handles, labels
    linesShape = [[], []]

    # Prebuild SHAPE legend once, based on global shape_to_marker
    if shape_by is not None and len(shape_levels) > 0:
        for cat2 in shape_levels:
            mk = shape_to_marker[cat2]
            # neutral/black marker to illustrate shape
            handle = ax.scatter([], [], label=str(cat2), color="black", marker=mk)
            linesShape[0].append(handle)
            linesShape[1].append(cat2)

    # ---- Plot by COLOR groups ----
    for i, cat1 in enumerate(cats1):
        sel1 = meta[color_by] == cat1
        meta_i = meta.loc[sel1]
        coords_i = coords.loc[sel1]

        # Optional connection lines (within color group)
        if connect is not None:
            order_vals = pd.to_numeric(meta_i[connect], errors="coerce")
            sorter = np.argsort(order_vals.values)
            ax.plot(
                coords_i.iloc[sorter, 0], coords_i.iloc[sorter, 1],
                color=colors[i], lw=lw
            )

        # Color legend stub
        if color_legend_style == "patch":
            linesColor[0].append(Patch(facecolor=colors[i], edgecolor="none", label=str(cat1)))
        else:
            linesColor[0].append(
                ax.scatter([], [], label=str(cat1), color=colors[i], marker=color_legend_marker)
            )
        linesColor[1].append(cat1)

        # Plot shapes within this color group
        if shape_by is not None and len(shape_levels) > 0:
            # Iterate only over shapes present in this color slice, but look up marker from the global map
            for cat2 in pd.unique(meta_i[shape_by]):
                sel2 = meta_i[shape_by] == cat2
                xi = coords_i.loc[sel2, xn]
                yi = coords_i.loc[sel2, yn]
                mk = shape_to_marker.get(cat2, markers[0])  # fallback safe-guard
                ax.scatter(xi, yi, color=colors[i], marker=mk, s=markersize)
        else:
            # No shape dimension: use a (color-specific) marker for the color legend
            mk = markers[i % len(markers)]
            # Replace the legend patch with an example marker so legend matches the plot symbols
            if color_legend_style == "patch":
                linesColor[0][-1] = ax.scatter([], [], label=str(cat1), color=colors[i], marker=mk)
            ax.scatter(coords_i[xn], coords_i[yn], color=colors[i], marker=mk, s=markersize)

    # Default legend anchors if not provided
    if legend_pos_colors is None:
        legend_pos_colors = (1.0, 1.0)
    if legend_pos_shapes is None:
        legend_pos_shapes = (1.0, 0.4)

    # ---- Legends in a dedicated legend axis (Option B) ----
    if legend and ax_leg is not None:
        ax_leg.cla()
        ax_leg.axis("off")

        # Color legend (top)
        if len(linesColor[0]) > 0:
            title1 = (
                color_legend_title
                if color_legend_title is not None
                else (legend_titles[0] if legend_titles[0] else (color_by if color_by != "_all_" else ""))
            )
            leg1 = ax_leg.legend(
                handles=linesColor[0],
                labels=[str(lbl) for lbl in linesColor[1]],
                title=title1,
                frameon=False,
                loc="upper left",
                fontsize=fontsize,
                markerscale=markerscale,
            )
            try:
                leg1.get_title().set_ha("left")
                if hasattr(leg1, "_legend_title_box"):
                    leg1._legend_title_box.align = "left"
                if hasattr(leg1, "_legend_box"):
                    leg1._legend_box.align = "left"
            except Exception:
                pass

        # Shape legend (below)
        if shape_by is not None and len(linesShape[0]) > 0:
            from matplotlib.legend import Legend
            title2 = shape_legend_title if shape_legend_title is not None else (legend_titles[1] if legend_titles[1] else shape_by)
            leg2 = Legend(
                ax_leg, linesShape[0], linesShape[1],
                ncol=1, loc="lower left", bbox_to_anchor=(0.0, 0.0),
                frameon=False, title=title2,
                fontsize=fontsize, markerscale=markerscale,
            )
            ax_leg.add_artist(leg2)
            try:
                leg2.get_title().set_ha("left")
                if hasattr(leg2, "_legend_title_box"):
                    leg2._legend_title_box.align = "left"
                if hasattr(leg2, "_legend_box"):
                    leg2._legend_box.align = "left"
            except Exception:
                pass

        return  # done with Option B

    # ---- Fallback: legends anchored to the plotting axis ----
    if legend:
        # Color legend
        title1 = legend_titles[0] if legend_titles[0] else (color_by if color_by != "_all_" else "")
        leg1 = ax.legend(
            linesColor[0],
            [str(lbl) for lbl in linesColor[1]],
            ncol=1,
            bbox_to_anchor=legend_pos_colors,
            title=title1,
            frameon=False,
            markerscale=markerscale,
            fontsize=fontsize,
            loc=2,
        )
        try:
            leg1.get_title().set_ha("left")
            if hasattr(leg1, "_legend_title_box"):
                leg1._legend_title_box.align = "left"
            if hasattr(leg1, "_legend_box"):
                leg1._legend_box.align = "left"
            leg1.set_in_layout(True)
        except Exception:
            pass

        # Shape legend
        if shape_by is not None and len(linesShape[0]) > 0:
            from matplotlib.legend import Legend
            title2 = legend_titles[1] if legend_titles[1] else shape_by
            leg2 = Legend(
                ax, linesShape[0], linesShape[1],
                ncol=1,
                bbox_to_anchor=legend_pos_shapes,
                title=title2,
                frameon=False,
                markerscale=markerscale,
                fontsize=fontsize,
                loc=2,
            )
            ax.add_artist(leg2)
            try:
                leg2.get_title().set_ha("left")
                if hasattr(leg2, "_legend_title_box"):
                    leg2._legend_title_box.align = "left"
                if hasattr(leg2, "_legend_box"):
                    leg2._legend_box.align = "left"
                leg2.set_in_layout(True)
            except Exception:
                pass


# ------------------------------
# Main: ordination plot function
# ------------------------------
def ordination(
    ordination_results: Union[pd.DataFrame, Dict[str, Union[pd.DataFrame, dict]]] = None,
    meta: Union[pd.DataFrame, Dict[str, Any], Any] = None,
    *,
    color_by: Optional[str] = None,
    shape_by: Optional[str] = None,
    order: Optional[str] = None,
    biplot: Optional[List[str]] = None,
    ellipse: Optional[str] = None,
    title: str = "",
    savename: Optional[str] = None,
    show_legend: bool = True,
    figsize: Tuple[float, float] = (9, 6),
    fontsize: int = 12,
    markersize: float = 50,
    markerscale: float = 1.1,
    lw: float = 1.0,
    pad: float = 1.1,
    flipx: bool = False,
    flipy: bool = False,
    hide_ticks: bool = False,
    connect: Optional[str] = None,
    ellipse_connect: Optional[str] = None,
    ellipse_std: float = 2.0,
    tag: Optional[str] = None,
    return_data: bool = False,
    colors: Optional[List[str]] = None,
    markers: Optional[List[str]] = None,
    ellipse_colors: Optional[List[str]] = None,
    color_legend_marker: Optional[str] = None,
    color_legend_title: Optional[str] = None,
    shape_legend_title: Optional[str] = None,
    which_axes: Tuple[int, int] = (0, 1),
    ax: Optional[plt.Axes] = None,
    legend_pos_colors: Tuple[float, float] = (1.0, 1.0),
    legend_pos_shapes: Tuple[float, float] = (1.0, 0.4),
) -> Tuple["plt.Figure", "plt.Axes", "pd.DataFrame", Optional["pd.DataFrame"]]:
    """
    Create an ordination plot (PCoA or db-RDA) with optional grouping, biplots, and annotations.

    Parameters
    ----------
    ordination_results : pandas.DataFrame or dict
        - If a dissimilarity matrix (square DataFrame) is provided,
          stats.pcoa_lingoes will be run and a PCoA plotted.
        - If a results dict from stats.pcoa_lingoes or stats.dbrda is
          provided, those results will be plotted directly.
    meta : DataFrame | MicrobiomeData-like | dict
        Metadata table with sample annotations.
    color_by : str, optional
        Column in `meta` used to color points by group.
    shape_by : str, optional
        Column in `meta` used to vary marker shapes by group.
    order : str, optional
        Metadata column by which samples are sorted (ascending) before
        plotting; this determines the order in which groups appear.
    biplot : list of str, optional
        For PCoA: list of numeric metadata columns to display as biplot vectors.
        For db-RDA: ignored; 'biplot_scores' from the ordination results are used.
    ellipse : str, optional
        Column in `meta` used to group samples for drawing ellipses.
    title : str, optional
        Plot title.
    savename : str, optional
        Filename to save the figure. The figure is written to `savename`
        (extension determines format, e.g. `.png`) and additionally as a PDF
        to `savename + ".pdf"`.
    show_legend : bool, default=True
        Whether to display the legend.
    figsize : tuple of float, default=(9, 6)
        Figure size in inches (only used when `ax` is None).
    fontsize : int, default=12
        Font size for labels and title. When `ax` is None this is applied
        through matplotlib's global rcParams.
    markersize : float, default=50
        Size of scatter plot markers.
    markerscale : float, default=1.1
        Scaling factor for legend markers.
    lw : float, default=1.0
        Line width for ellipses and connections.
    pad : float, default=1.1
        Padding factor for axis limits; each limit is moved away from the
        data by ``(pad - 1)`` times its absolute value.
    flipx : bool, default=False
        Flip the X-axis.
    flipy : bool, default=False
        Flip the Y-axis.
    hide_ticks : bool, default=False
        Hide axis tick labels.
    connect : str, optional
        Column in `meta` to connect points in order (e.g., time series).
    ellipse_connect : str, optional
        Column in `meta` to connect ellipse centers in order.
    ellipse_std : float, default=2.0
        Number of standard deviations around the centroid that the ellipse is drawn.
    tag : str, optional
        Column in `meta` or 'index' to annotate points.
    return_data : bool, default=False
        Currently unused; the data are always part of the returned tuple.
    colors : list of str, optional
        Custom list of colors for groups.
    markers : list of str, optional
        Custom list of marker styles for groups.
    ellipse_colors : list of str or str, optional
        Custom color(s) for ellipses; a single string is used for all groups.
    color_legend_marker : str, default=None
        Shape of color legend marker. Defaults to a rectangular patch.
    color_legend_title : str, optional
        Color legend title.
    shape_legend_title : str, optional
        Marker legend title.
    which_axes : tuple of int, default=(0, 1)
        Indices of ordination axes to plot.
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw the plot on. If None, a new figure is created.
    legend_pos_colors : tuple of float, default=(1, 1)
        Position of color legend. Only relevant for user-supplied ax.
    legend_pos_shapes : tuple of float, default=(1, 0.4)
        Position of shape legend. Only relevant for user-supplied ax.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The matplotlib Figure object for the ordination.
    ax : matplotlib.axes.Axes
        The matplotlib Axes object for the ordination.
    meta : pandas.DataFrame
        Metadata with the two plotted ordination coordinates appended.
    Uproj : pandas.DataFrame or None
        Scaled biplot coordinates, or None if no biplot was drawn.

    Raises
    ------
    ValueError
        If inputs are missing, samples in `meta` and the ordination do not
        match, `order` is not a metadata column, `which_axes` exceeds the
        available axes, or `ellipse_colors` has an unsupported type.

    Notes
    -----
    - Supports both PCoA and db-RDA ordination results.
    - Ellipses represent group dispersion; biplots show variable contributions.
    - Axis flipping and padding allow fine control over plot appearance.

    Examples
    --------
    >>> fig, ax, data, arrows = ordination(pcoa_results, meta, color_by='Treatment', ellipse='Group')
    >>> fig.savefig('ordination_plot.png')
    """
    # Validation
    if ordination_results is None:
        raise ValueError('ordination_results are missing.')
    meta = get_df(meta, "meta")
    if meta is None:
        raise ValueError('meta data is missing.')

    # Extract normalized ordination payload
    payload = _extract_ordination_payload(ordination_results)
    coords_df = payload['coords_df'].copy()
    axis_names = payload['axis_names']
    pct_explained = payload['pct_explained']
    eigenvalues = payload['eigenvalues']
    biplot_df = payload['biplot_df']
    kind = payload['kind']  # 'pcoa' or 'dbrda'

    # Align meta to ordination sample order
    if not coords_df.index.equals(meta.index):
        common = coords_df.index.intersection(meta.index)
        if len(common) != len(coords_df.index):
            raise ValueError("Samples in metadata don't match samples in ordination site scores.")
        meta = meta.loc[coords_df.index]

    # NOTE(review): indentation was reconstructed from a flattened source; the
    # `order` handling is assumed to be independent of the alignment above.
    if order is not None and order not in meta.columns:
        raise ValueError("order is missing in metadata.")
    elif order is not None:
        meta = meta.sort_values(by=order, ascending=True)
        coords_df = coords_df.loc[meta.index]

    # Pick axes
    if len(axis_names) < max(which_axes) + 1:
        raise ValueError(f"Requested axes {which_axes} exceed available axes ({len(axis_names)}).")
    xn_name = axis_names[which_axes[0]]
    yn_name = axis_names[which_axes[1]]

    # Subset coordinates to 2D
    coords = coords_df[[xn_name, yn_name]].copy()

    # Axis labels with explained %
    def _axis_label(name):
        pct = pct_explained.get(name, np.nan) if hasattr(pct_explained, "get") else np.nan
        suffix = "" if pd.isna(pct) else f" ({pct:.2f}%)"
        return f"{name}{suffix}"

    xlab = _axis_label(xn_name)
    ylab = _axis_label(yn_name)

    # Plot bounds: move each limit away from the data. When the scores
    # straddle zero this equals min*pad / max*pad; a plain multiplication
    # would otherwise pull a positive minimum (or negative maximum) inward
    # and clip points.
    def _padded_limits(values):
        lo, hi = values.min(), values.max()
        lo = lo * pad if lo <= 0 else lo - (pad - 1.0) * lo
        hi = hi * pad if hi >= 0 else hi - (pad - 1.0) * hi
        return [lo, hi]

    xaxislims = _padded_limits(coords.iloc[:, 0])
    yaxislims = _padded_limits(coords.iloc[:, 1])

    # Compute biplot arrows
    Uproj = None
    if biplot_df is not None:
        # dbRDA: use given biplot scores
        arrows2d = biplot_df[[xn_name, yn_name]].copy()
        Uproj = _scale_arrows_to_limits(arrows2d, xaxislims, yaxislims)
    elif biplot and kind == 'pcoa':
        evx = float(eigenvalues.get(xn_name, 1.0))
        evy = float(eigenvalues.get(yn_name, 1.0))
        Uproj = _compute_pcoa_biplot(coords, meta, biplot, evx, evy)
        Uproj = _scale_arrows_to_limits(Uproj, xaxislims, yaxislims)

    # --- Plot setup (dedicated legend column when ax is None) ---
    if ax is None:
        plt.rcParams.update({'font.size': fontsize})
        fig = plt.figure(figsize=figsize, constrained_layout=True)

        # Reserve right column for legends
        auto_frac = _auto_legend_fraction_fast(
            meta,
            color_by=color_by,
            shape_by=shape_by,
            figure_width_in=fig.get_figwidth(),
            fontsize=fontsize,
            marker_em=markerscale,
            pad_em=1.0,
            min_fraction=0.16,
            max_fraction=0.48,
        )
        gs = fig.add_gridspec(ncols=2, nrows=1, width_ratios=[1.0 - auto_frac, auto_frac])
        ax = fig.add_subplot(gs[0, 0])
        ax_leg = fig.add_subplot(gs[0, 1])
        ax_leg.axis("off")  # hide ticks/frames on legend column
    else:
        fig = ax.figure
        ax_leg = None  # fallback to axes-anchored legends if external ax is supplied

    # --- Determine color legend style from marker preference ---
    # If user does not provide a marker for the color legend, use neutral rectangles ("patch")
    color_legend_style = "patch" if color_legend_marker is None else "marker"

    # --- Points & legends ---
    _draw_points(
        ax, coords, meta,
        color_by=color_by, shape_by=shape_by,
        colors=colors, markers=markers,
        markersize=markersize, lw=lw,
        connect=connect, legend=show_legend,
        markerscale=markerscale, fontsize=fontsize,
        legend_pos_colors=legend_pos_colors,
        legend_pos_shapes=legend_pos_shapes,
        ax_leg=ax_leg,
        color_legend_style=color_legend_style,
        color_legend_marker=color_legend_marker,
        color_legend_title=color_legend_title,
        shape_legend_title=shape_legend_title
    )

    # Ellipses
    if ellipse is not None and ellipse in meta.columns:
        e_counts = len(meta[ellipse].unique())
        if ellipse_colors is None or (isinstance(ellipse_colors, list) and len(ellipse_colors) == 0):
            # An empty list is treated like "not given" (it cannot be cycled).
            e_colors = _default_colors(len(pd.unique(meta[ellipse])))
        elif isinstance(ellipse_colors, str):
            e_colors = [ellipse_colors] * e_counts
        elif isinstance(ellipse_colors, list) and len(ellipse_colors) < e_counts:
            # Repeat a short list until it covers every group.
            e_colors = ellipse_colors * (int(e_counts / len(ellipse_colors)) + 1)
        elif isinstance(ellipse_colors, list):
            e_colors = ellipse_colors
        else:
            raise ValueError("ellipse_colors is not correctly defined. Should be list, str, or None.")

        _draw_ellipses(
            ax, coords, meta,
            group_col=ellipse, n_std=ellipse_std,
            edge_color='grey', lw=lw,
            label_centers=False,
            connect_by=ellipse_connect,
            colors=e_colors
        )

    # Arrow overlay
    if Uproj is not None and len(Uproj) > 0:
        xn, yn = Uproj.columns.tolist()
        for var_name in Uproj.index:
            vx = float(Uproj.loc[var_name, xn])
            vy = float(Uproj.loc[var_name, yn])
            if not (np.isfinite(vx) and np.isfinite(vy)):
                # e.g. a biplot variable with missing values: nothing to draw
                continue
            # Place the label on the far side of the arrow head.
            ha = 'left' if vx > 0 else ('right' if vx < 0 else 'center')
            va = 'bottom' if vy > 0 else ('top' if vy < 0 else 'center')
            ax.arrow(0, 0, vx, vy, color='black', width=0.001)
            ax.annotate(var_name, (1.03 * vx, 1.03 * vy),
                        horizontalalignment=ha, verticalalignment=va)
        # NOTE(review): indentation was reconstructed from a flattened source;
        # the zero reference lines are assumed to belong to the arrow overlay.
        ax.axhline(0, 0, 1, linestyle='--', color='grey', lw=0.5)
        ax.axvline(0, 0, 1, linestyle='--', color='grey', lw=0.5)

    # Point tags
    if tag is not None:
        if tag == 'index':
            for ix in meta.index:
                ax.annotate(str(ix), (coords.loc[ix, xn_name], coords.loc[ix, yn_name]))
        elif tag in meta.columns:
            for ix in meta.index:
                tagtext = str(meta.loc[ix, tag])
                ax.annotate(tagtext, (coords.loc[ix, xn_name], coords.loc[ix, yn_name]))

    # Final formatting
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.set_xlim(xaxislims)
    ax.set_ylim(yaxislims)
    if flipx:
        ax.invert_xaxis()
    if flipy:
        ax.invert_yaxis()
    if hide_ticks:
        ax.set_xticklabels([])
        ax.set_yticklabels([])
    ax.set_title(title)

    # Save: once under the given name, once as an additional PDF copy
    if savename is not None:
        fig.savefig(savename, bbox_inches="tight", dpi=240)
        fig.savefig(f"{savename}.pdf", format="pdf", bbox_inches="tight")

    # Return (Uproj is None when no biplot was requested/available)
    return fig, ax, pd.concat([meta, coords], axis=1), Uproj