"""Plotting functions for qdiv diversity analyses (module ``qdiv.plot.diversity_plots``)."""

import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.gridspec import GridSpec

from ..diversity import dissimilarity_by_feature, naive_multi_beta, phyl_multi_beta, func_multi_beta
from ..diversity import naive_alpha, phyl_alpha, func_alpha
from ..io import subset_samples
from ..utils import get_df, parse_leaves, get_colors_markers

# -----------------------------------------------------------------------------
# Plot dissimilarity contribution of features
# -----------------------------------------------------------------------------
def dissimilarity_contributions(
    obj: Union[Dict[str, Any], Any],
    *,
    by: Optional[str] = None,
    q: float = 1.0,
    div_type: str = "naive",
    index: str = "local",
    n: int = 20,
    levels: Optional[List[str]] = None,
    from_file: Optional[str] = None,
    figsize: Tuple[float, float] = (18 / 2.54, 14 / 2.54),
    fontsize: int = 10,
    savename: Optional[str] = None,
) -> Tuple["plt.figure.Figure", "pd.DataFrame"]:
    """
    Plot contributions of features to observed dissimilarity within categories.

    One horizontal bar chart is drawn per category (column of the contribution
    table). Features are ranked by their mean contribution across categories
    and only the top ``n`` are shown.

    Parameters
    ----------
    obj : dict or MicrobiomeData
        Input data. Read via ``get_df``:
          - ``tab`` (pandas.DataFrame): abundance table (features x samples),
            used by ``dissimilarity_by_feature`` when `from_file` is None.
          - ``tax`` (pandas.DataFrame, optional): taxonomy table
            (features x taxonomic levels), used for labels when `levels` is given.
          - ``tree`` (pandas.DataFrame, optional): tree table with 'nodes' and
            'leaves' columns, used for labels when ``div_type='phyl'``.
          - ``meta`` (pandas.DataFrame): metadata table.
    by : str, optional
        Metadata column used to categorize samples. Dissimilarity is
        calculated within each category.
    q : float, default=1.0
        Diversity order (Hill number).
    div_type : {'naive', 'phyl'}, default='naive'
        Diversity type, passed to ``dissimilarity_by_feature``:
          - 'naive': taxonomic dissimilarity.
          - 'phyl': phylogenetic dissimilarity.
    index : {'local', 'regional'}, default='local'
        Index type for dissimilarity calculation.
    n : int, default=20
        Number of top features to include in the plot.
    levels : str or list of str, optional
        Taxonomic levels to include in y-axis labels (e.g., ['Genus']).
        Taxonomy names of 3 characters or fewer are skipped.
    from_file : str, optional
        Path to a CSV file with precomputed contributions (same layout as the
        output of ``dissimilarity_by_feature``, including the 'N' and 'dis'
        rows). If None, contributions are computed from `obj`.
    figsize : tuple of float, default=(18/2.54, 14/2.54)
        Figure size in inches.
    fontsize : int, default=10
        Font size for plot text (set globally via ``plt.rcParams``).
    savename : str, optional
        If provided, save the figure to this path and also to
        ``savename + '.pdf'``. A failure to write the PDF copy is reported
        as a ``RuntimeWarning`` rather than raised.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure.
    df : pandas.DataFrame
        Contributions of the plotted features (top `n` rows), with an added
        'avg' column holding the mean over categories used for ranking.

    Notes
    -----
    - The rows 'N' (number of samples) and 'dis' (total dissimilarity) of the
      contribution table are used only for the subplot titles.
    - Columns whose name contains 'nodes' are treated as labels, not categories.
    - If `levels` is given and ``div_type='naive'``, taxonomy names are
      prepended to feature IDs.
    - If `levels` is given and ``div_type='phyl'``, branches leading to a single
      leaf are labeled like features; other branches use their node name.

    Examples
    --------
    >>> fig, df = dissimilarity_contributions(obj, by='Treatment', q=1,
    ...                                       div_type='naive', levels=['Genus'])
    >>> print(df.head())
    """
    tax = get_df(obj, "tax")
    tree = get_df(obj, "tree")
    if isinstance(levels, str):
        levels = [levels]

    # Load or compute dissimilarity contributions
    if from_file is None:
        dis_data = dissimilarity_by_feature(obj, by=by, q=q, div_type=div_type, index=index)
    else:
        dis_data = pd.read_csv(from_file, index_col=0)

    # 'N' and 'dis' are summary rows (used in titles); the rest are per-feature rows.
    df = dis_data.drop(["N", "dis"], axis=0)
    # str() so that non-string category labels (e.g. integer groups) do not raise.
    catlist = [x for x in df.columns if "nodes" not in str(x)]
    df["avg"] = df[catlist].mean(axis=1)
    df = df.sort_values(by="avg", ascending=False).iloc[:n]

    ylist = range(len(df.index))
    taxlist = df.index.tolist()

    def _tax_label(asv: Any, tax_table: pd.DataFrame) -> Any:
        # "<tax1>; <tax2>; <asv>", skipping taxonomy names of <= 3 characters.
        parts = [tax_table.loc[asv, lvl] for lvl in levels if len(tax_table.loc[asv, lvl]) > 3]
        taxname = "; ".join(parts)
        return f"{taxname}; {asv}" if taxname else asv

    # Append taxonomy names if requested
    if levels is not None and tax is not None and div_type == "naive":
        tax_df = tax.loc[df.index].fillna("").astype(str)
        for i, asv in enumerate(df.index):
            taxlist[i] = _tax_label(asv, tax_df)
    elif levels is not None and div_type == "phyl" and tree is not None:
        tree_df = tree.loc[df.index]
        # Taxonomy is optional here: without it, single-leaf branches keep the leaf ID.
        tax_str = tax.fillna("").astype(str) if tax is not None else None
        for i, ix in enumerate(tree_df.index):
            asvlist = list(parse_leaves(tree_df.loc[ix, "leaves"]))
            if len(asvlist) == 1:
                asv = asvlist[0]
                if tax_str is not None and asv in tax_str.index:
                    taxlist[i] = _tax_label(asv, tax_str)
                else:
                    taxlist[i] = asv
            else:
                taxlist[i] = tree_df.loc[ix, "nodes"]

    # Plot: one panel per category, sharing the same y positions
    plt.rcParams.update({"font.size": fontsize})
    fig = plt.figure(figsize=figsize, constrained_layout=True)
    gs = fig.add_gridspec(1, len(catlist))
    gs.update(wspace=0, hspace=0)
    for cat_nr, cat in enumerate(catlist):
        ax = fig.add_subplot(gs[0, cat_nr], frame_on=True)
        ax.barh(ylist, df[cat])
        ax.set_yticks(range(len(df.index)))
        # Only the left-most panel carries feature labels.
        ax.set_yticklabels(taxlist if cat_nr == 0 else [])
        ax.set_xlabel("%")
        title_text = (
            f"{cat}\nN={int(dis_data.loc['N', cat])}\n"
            f"$^{{{q}}}$d={round(dis_data.loc['dis', cat], 2)}"
        )
        ax.set_title(title_text)

    if savename:
        fig.savefig(savename, dpi=240)
        try:
            fig.savefig(f"{savename}.pdf", format="pdf")
        except Exception as err:  # best effort: the PDF copy is optional
            warnings.warn(
                f"Could not save PDF copy '{savename}.pdf': {err}",
                RuntimeWarning,
                stacklevel=2,
            )
    return fig, df
# ----------------------------------------------------------------------------- # Plot phylogram # -----------------------------------------------------------------------------
def phyl_tree(
    obj: Union[Dict[str, Any], Any],
    *,
    width: float = 12,
    name_internal_nodes: bool = False,
    abundance_info: Optional[str] = None,
    xlog: bool = False,
    savename: Optional[str] = None,
) -> Tuple["plt.figure.Figure", "pd.DataFrame"]:
    """
    Plot a phylogram from a tree DataFrame with optional abundance bars.

    Parameters
    ----------
    obj : dict or MicrobiomeData
        Input object with required key:
          - ``tree`` (pandas.DataFrame): tree structure with columns
            ['nodes', 'leaves', 'branchL', 'dist_to_root'].
        Optional keys:
          - ``tab`` (pandas.DataFrame): abundance table (features x samples).
          - ``meta`` (pandas.DataFrame): metadata table for sample grouping.
    width : float, default=12
        Width of the plot in centimeters. Height is set automatically from
        the number of leaves.
    name_internal_nodes : bool, default=False
        If True, internal node names are written at their positions.
    abundance_info : {'index'} or str, optional
        If 'index', plot mean relative abundance bars per sample.
        If a metadata column name, plot mean relative abundance bars per
        category of that column. Ignored when `tab` or `meta` is missing.
    xlog : bool, default=False
        If True, abundance bars use a logarithmic x-axis.
    savename : str, optional
        If provided, save the figure to this path and also to
        ``savename + '.pdf'``. A failure to write the PDF copy is reported
        as a ``RuntimeWarning`` rather than raised.

    Returns
    -------
    fig : matplotlib.figure.Figure
    df_endN : pandas.DataFrame
        Leaf rows of the tree indexed by node name, with 'xpos'/'ypos' columns
        and one 'ra:<category>' column per abundance category. Note that
        'xpos' and 'ypos' are overwritten while the tree is drawn.

    Notes
    -----
    - Internal nodes are rows whose 'nodes' value starts with 'in'; leaves are
      all other rows except 'Root'.
    - Internal nodes are drawn smallest clade first, so each clade is drawn
      after all of its sub-clades.

    Examples
    --------
    >>> fig, df = phyl_tree(obj, width=15, name_internal_nodes=True,
    ...                     abundance_info='Treatment', xlog=True, savename='phylogram')
    """
    tree = get_df(obj, "tree")
    tab = get_df(obj, "tab")
    meta = get_df(obj, "meta")
    if tree is None:
        raise ValueError("Error: 'tree' not found in obj.")
    df = tree.copy()

    # Separate leaves and internal nodes ('in' prefix); explicit copies avoid chained-assignment issues.
    is_internal = df["nodes"].str.startswith("in")
    df_endN = df[(~is_internal) & (df["nodes"] != "Root")].set_index("nodes").copy()
    df_intN = df[is_internal].set_index("nodes").copy()

    # Float from the start: ypos is later overwritten with midpoints such as 2.5.
    df_endN["ypos"] = np.arange(len(df_endN.index), dtype=float)
    df_endN["xpos"] = df_endN["dist_to_root"].astype(float)

    # Parse each internal node's leaf list once and process small clades first.
    leaves_of = {node: list(parse_leaves(leaves)) for node, leaves in df_intN["leaves"].items()}
    df_intN["asv_count"] = [len(leaves_of[node]) for node in df_intN.index]
    df_intN = df_intN.sort_values("asv_count", ascending=True)

    # Compute mean relative abundance per category if requested
    catlist: List[Any] = []
    if abundance_info and tab is not None and meta is not None:
        if abundance_info != "index":
            catlist = meta[abundance_info].dropna().unique().tolist()
        else:
            catlist = meta.index.tolist()
        for cat in catlist:
            col = f"ra:{cat}"
            df_endN[col] = 0.0
            temp_obj = subset_samples(obj, by=abundance_info, values=[cat])
            temp_tab = get_df(temp_obj, "tab")
            ra = (temp_tab / temp_tab.sum()).mean(axis=1)
            # Only features that are leaves of the tree can be placed on the plot.
            shared = ra.index.intersection(df_endN.index)
            df_endN.loc[shared, col] = ra.loc[shared]
    show_bars = len(catlist) > 0

    # Plot tree
    textspacing = df_endN["xpos"].max() / 50
    plt.rcParams.update({"font.size": 10})
    fig = plt.figure(figsize=(width / 2.54, 0.7 * len(df_endN.index) / 2.54), constrained_layout=True)
    gs = fig.add_gridspec(1, 10)
    gs.update(wspace=0, hspace=0)
    ax = fig.add_subplot(gs[0, :9] if show_bars else gs[0, :10], frame_on=True)

    # Leaves: label, draw terminal branch, then move xpos to the parent's x.
    for node in df_endN.index:
        ypos = df_endN.loc[node, "ypos"]
        xpos = df_endN.loc[node, "xpos"]
        ax.text(xpos + textspacing, ypos, node, va="center", color="red")
        node_BL = df_endN.loc[node, "branchL"]
        ax.plot([xpos - node_BL, xpos], [ypos, ypos], lw=1, color="black")
        df_endN.loc[node, "xpos"] = xpos - node_BL

    # Internal nodes: vertical connector across the clade, then the stem branch.
    # All leaves of the clade are collapsed onto the stem's start point.
    for intN in df_intN.index:
        asvlist = leaves_of[intN]
        xpos = df_endN.loc[asvlist, "xpos"].mean()
        ymax = df_endN.loc[asvlist, "ypos"].max()
        ymin = df_endN.loc[asvlist, "ypos"].min()
        ax.plot([xpos, xpos], [ymin, ymax], lw=1, color="black")
        ymean = (ymax + ymin) / 2
        xmin = xpos - df_intN.loc[intN, "branchL"]
        ax.plot([xmin, xpos], [ymean, ymean], lw=1, color="black")
        df_endN.loc[asvlist, ["ypos", "xpos"]] = [float(ymean), float(xmin)]
        if name_internal_nodes:
            # 'nodes' is the index of df_intN, so the label is the index value itself.
            ax.text(xpos, ymean, intN, va="center", color="red")

    ax.plot([0, 0], [df_endN["ypos"].min(), df_endN["ypos"].max()], lw=1, color="black")
    ax.set_ylim(-1, len(df_endN.index))
    ax.axis("off")

    # Abundance bars, aligned with the original (pre-traversal) leaf rows
    if show_bars:
        ax2 = fig.add_subplot(gs[0, 9], frame_on=True)
        bars_leg1, bars_leg2 = [], []
        orig_ypos = np.arange(len(df_endN.index))
        bar_thickness = 0.8 / len(catlist)
        for cat_nr, cat in enumerate(catlist):
            bar_yoffset = bar_thickness * (cat_nr - (len(catlist) - 1) / 2)
            bl = ax2.barh(
                orig_ypos + bar_yoffset,
                df_endN[f"ra:{cat}"],
                height=0.95 * bar_thickness,
                label=cat,
            )
            bars_leg1.append(bl)
            bars_leg2.append(cat)
        if xlog:
            ax2.set_xscale("log")
        ax2.set_ylim(-1, len(df_endN.index))
        ax2.set_xticks([])
        ax2.set_yticks([])
        ax.legend(bars_leg1, bars_leg2, loc="lower right", bbox_to_anchor=(1, 1), ncol=4, frameon=False)

    if savename:
        fig.savefig(savename, dpi=240)
        try:
            fig.savefig(f"{savename}.pdf", format="pdf")
        except Exception as err:  # best effort: the PDF copy is optional
            warnings.warn(
                f"Could not save PDF copy '{savename}.pdf': {err}",
                RuntimeWarning,
                stacklevel=2,
            )
    return fig, df_endN
# ----------------------------------------------------------------------------- # Plot harvey balls from metadata # -----------------------------------------------------------------------------
def harvey_balls(
    meta: Union[pd.DataFrame, Dict[str, Any], Any],
    columns_by: List[str] = None,
    *,
    rows_by: str = "index",
    row_colors: Optional[str] = None,
    column_colors: Optional[List[str]] = None,
    row_label_width: int = 4,
    figsize: Tuple[float, float] = (18 / 2.54, 14 / 2.54),
    fontsize: int = 10,
    savename: Optional[str] = None,
) -> Tuple["plt.figure.Figure", "pd.DataFrame"]:
    """
    Plot Harvey balls (fraction-of-circle indicators) for percentage columns in metadata.

    Parameters
    ----------
    meta : DataFrame | MicrobiomeData-like | dict
        Object with a metadata table (read via ``get_df(meta, 'meta')``).
        Must contain all `columns_by` fields and, unless ``rows_by='index'``,
        the `rows_by` field.
    columns_by : str or list of str
        Metadata columns containing percentages (0–100) to draw as Harvey balls.
    rows_by : str, default='index'
        Metadata column used as row labels. If 'index', the DataFrame index is used.
    row_colors : str, optional
        Metadata column containing per-row text colors (e.g., 'red', '#333').
        If None, all row labels are black.
    column_colors : list of str, optional
        Colors for the column headers (one per `columns_by`). Missing entries
        are padded with black.
    row_label_width : int, default=4
        Number of GridSpec columns reserved for the row-label area.
    figsize : tuple of float, default=(18/2.54, 14/2.54)
        Figure size in inches.
    fontsize : int, default=10
        Base font size (set globally via ``plt.rcParams``).
    savename : str, optional
        If provided, save the figure to this path and also to
        ``savename + '.pdf'``. A failure to write the PDF copy is reported
        as a ``RuntimeWarning`` rather than raised.

    Returns
    -------
    fig : matplotlib.figure.Figure
    plot_data : pandas.DataFrame
        Row labels and the percentage values used for plotting, with
        columns ['__label__', *columns_by].
        ``None`` is returned (after printing a message) only when `row_colors`
        names a column that is not in the metadata.

    Raises
    ------
    ValueError
        If metadata is missing/empty, `columns_by` is None, any `columns_by`
        column is missing, or `rows_by` is neither 'index' nor a metadata column.

    Notes
    -----
    - The **black** wedge represents the percentage; the **white** wedge the
      complement to 100%.
    - Values are coerced to numeric; non-numeric or missing values are
      treated as 0, and all values are clipped to [0, 100].
    - Rows are matched by position, so duplicate index labels are supported.

    Examples
    --------
    >>> fig, df = harvey_balls(
    ...     meta,
    ...     rows_by='Treatment',
    ...     columns_by=['PFAS_%', 'DOC_%'],
    ...     row_colors='TreatmentColor',
    ...     column_colors=['#1f77b4', '#ff7f0e'],
    ...     savename='harvey_balls'
    ... )
    >>> print(df.head())
    """
    # ---- Validation ---------------------------------------------------------
    meta = get_df(meta, "meta")
    if meta is None or meta.empty:
        raise ValueError("Error: meta is missing.")
    if columns_by is None:
        raise ValueError("Error: columns_by must be specified.")
    if isinstance(columns_by, str):
        columns_by = [columns_by]
    columns_by = list(columns_by)
    # Every requested column must exist, otherwise meta[columns_by] fails later.
    missing = [c for c in columns_by if c not in meta.columns]
    if missing:
        raise ValueError(f"columns_by must be columns in metadata; missing: {missing}")
    if rows_by != "index" and rows_by not in meta.columns:
        raise ValueError(f"rows_by must be 'index' or a column in metadata; got '{rows_by}'")

    # Header colors: accept any sequence, pad with black to one per column.
    column_colors = list(column_colors) if column_colors is not None else []
    column_colors += ["black"] * (len(columns_by) - len(column_colors))

    # ---- Prepare labels and colors -----------------------------------------
    if rows_by == "index":
        row_labels = meta.index.tolist()
    else:
        row_labels = meta[rows_by].astype(str).tolist()

    if row_colors is None:
        per_row_text_colors = ["black"] * len(row_labels)
    else:
        if row_colors not in meta.columns:
            print(f"Error: 'row_colors' column '{row_colors}' not found in metadata.")
            return None
        per_row_text_colors = meta[row_colors].astype(str).tolist()

    # Coerce to numeric, fill NAs with 0, clip to [0, 100]
    plot_data = (
        meta[columns_by]
        .apply(pd.to_numeric, errors="coerce")
        .fillna(0.0)
        .clip(lower=0.0, upper=100.0)
    )
    # Positional matrix of values, independent of (possibly duplicated) index labels.
    values = plot_data.to_numpy(dtype=float)
    plot_data.insert(0, "__label__", row_labels)

    # ---- Plot ---------------------------------------------------------------
    plt.rcParams.update({"font.size": fontsize})
    fig = plt.figure(figsize=figsize, constrained_layout=True)
    # rows: +1 for header row; columns: label area + one per metric
    gs = GridSpec(len(meta) + 1, row_label_width + len(columns_by), figure=fig)
    wedge_style = {"linewidth": 1, "edgecolor": "black"}

    # Header: label-area title
    ax01 = fig.add_subplot(gs[0, :row_label_width])
    ax01.text(0, 0, rows_by, va="center")
    ax01.set_ylim(-1, 1)
    ax01.axis("off")

    # Header: metric names
    for i, g in enumerate(columns_by):
        ax = fig.add_subplot(gs[0, row_label_width + i])
        ax.text(0, 0, g, color=column_colors[i], ha="center", va="center")
        ax.set_ylim(-1, 1)
        ax.set_xlim(-1, 1)
        ax.axis("off")

    # Rows
    for j, label in enumerate(row_labels):
        ax1 = fig.add_subplot(gs[j + 1, :row_label_width])
        ax1.text(0, 0, label, color=per_row_text_colors[j], va="center")
        ax1.set_ylim(-1, 1)
        ax1.axis("off")

        for i in range(len(columns_by)):
            ax = fig.add_subplot(gs[j + 1, row_label_width + i])
            black = values[j, i]  # percentage
            if black <= 0.0:
                ax.pie([100.0], colors=["white"], startangle=90, wedgeprops=wedge_style)
            elif black >= 100.0:
                # Full ball: avoid a zero-width white wedge drawing a stray radius line.
                ax.pie([100.0], colors=["black"], startangle=90, wedgeprops=wedge_style)
            else:
                # White first (counter-clockwise from 12 o'clock), so black fills clockwise from the top.
                ax.pie([100.0 - black, black], colors=["white", "black"], startangle=90, wedgeprops=wedge_style)
            ax.set_aspect("equal")

    if savename:
        fig.savefig(savename, dpi=240)
        try:
            fig.savefig(f"{savename}.pdf", format="pdf")
        except Exception as err:  # best effort: the PDF copy is optional
            warnings.warn(
                f"Could not save PDF copy '{savename}.pdf': {err}",
                RuntimeWarning,
                stacklevel=2,
            )
    return fig, plot_data
# ----------------------------------------------------------------------------- # Plot alpha diversity profiles # -----------------------------------------------------------------------------
def alpha_diversity_profile(
    obj: Union[dict, Any],
    *,
    q_range: Tuple[float, float] = (0.0, 2.0),
    q_step: float = 0.05,
    distmat: Optional[pd.DataFrame] = None,
    div_type: str = "naive",
    color_by: Optional[str] = None,
    order: Optional[str] = None,
    ylog: bool = False,
    figsize: Tuple[float, float] = (18 / 2.54, 14 / 2.54),
    fontsize: int = 10,
    colorlist: Optional[List[str]] = None,
    use_values_in_tab: bool = False,
    savename: Optional[str] = None,
) -> Tuple["plt.Figure", "plt.Axes", "pd.DataFrame"]:
    """
    Plot alpha diversity vs diversity order across samples.

    Alpha diversity (Hill numbers) is computed for each q in an inclusive grid
    and one curve is drawn per sample.

    Parameters
    ----------
    obj : dict or MicrobiomeData
        Input data containing at least:
          - ``tab`` (pandas.DataFrame): abundance table (features x samples).
          - ``meta`` (pandas.DataFrame): sample metadata.
        If ``div_type='phyl'``, must also provide ``tree``.
    q_range : tuple of float, default=(0.0, 2.0)
        Inclusive range (start, end) of diversity orders.
    q_step : float, default=0.05
        Step between q values. Must be positive.
    distmat : pandas.DataFrame, optional
        Functional distance matrix (features x features). Required when
        ``div_type='func'``.
    div_type : {'naive', 'phyl', 'func'}, default='naive'
        'naive' (taxonomic), 'phyl' (phylogenetic, requires ``tree``) or
        'func' (functional, requires `distmat`). Case-insensitive.
    color_by : str, optional
        Metadata column used to group curve colors and legend entries.
        If None, each sample is its own group.
    order : str, optional
        Metadata column used to sort samples before plotting.
    ylog : bool, default=False
        If True, use a logarithmic y-scale.
    figsize : tuple of float, default=(18/2.54, 14/2.54)
        Figure size in inches.
    fontsize : int, default=10
        Font size (set globally via ``plt.rcParams``).
    colorlist : list of str, optional
        Colors for groups. If None, uses ``get_colors_markers('colors')``,
        falling back to Matplotlib's color cycle if that fails or is empty.
    use_values_in_tab : bool, default=False
        Passed through to the alpha diversity backends.
    savename : str, optional
        If provided, save the figure to this path and also to
        ``savename + '.pdf'``. A failure to write the PDF copy is reported
        as a ``RuntimeWarning`` rather than raised.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
    df : pandas.DataFrame
        Rows = q values, columns = samples (in plotting order).

    Raises
    ------
    ValueError
        For missing/invalid inputs, or if backend results cannot be assigned.
    KeyError
        If `order` or `color_by` is not a metadata column.

    Notes
    -----
    Legend entries are de-duplicated by group: only the first sample of each
    group gets a legend label.

    Examples
    --------
    >>> fig, ax, df = alpha_diversity_profile(obj, q_range=(0, 2), q_step=0.1,
    ...                                       div_type='naive', color_by='Treatment')
    >>> df.head()
    """
    # --- Validate inputs ------------------------------------------------------
    tab = get_df(obj, "tab")
    meta = get_df(obj, "meta")
    if tab is None or not isinstance(tab, pd.DataFrame):
        raise ValueError("`obj` must contain a 'tab' DataFrame.")
    if meta is None or not isinstance(meta, pd.DataFrame):
        raise ValueError("`obj` must contain a 'meta' DataFrame.")

    if not isinstance(q_range, (tuple, list)) or len(q_range) != 2:
        raise ValueError("`q_range` must be a (start, end) tuple of floats.")
    q_start, q_end = float(q_range[0]), float(q_range[1])
    if q_start > q_end:
        raise ValueError("`q_range` must satisfy start <= end.")
    if q_step <= 0:
        raise ValueError("`q_step` must be positive.")

    div_type = str(div_type).lower()
    if div_type not in {"naive", "phyl", "func"}:
        raise ValueError("`div_type` must be one of {'naive', 'phyl', 'func'}.")

    if div_type == "phyl" and get_df(obj, "tree") is None:
        raise ValueError("`div_type='phyl'` requires `obj` to provide a 'tree' via get_df(obj, 'tree').")
    if div_type == "func":
        if distmat is None or not isinstance(distmat, pd.DataFrame):
            raise ValueError("`div_type='func'` requires `distmat` as a pandas.DataFrame.")
        if len(set(distmat.index).intersection(tab.index)) == 0:
            raise ValueError(
                "No overlapping features between `tab` rows and `distmat` index. "
                "Ensure both use the same feature IDs."
            )

    if color_by is not None and color_by not in meta.columns:
        raise KeyError(f"`color_by` column '{color_by}' not found in metadata.")

    # Sort samples if requested
    if order is not None:
        if order not in meta.columns:
            raise KeyError(f"`order` column '{order}' not found in metadata.")
        meta = meta.sort_values(by=order)

    smplist = meta.index.tolist()
    missing_cols = [c for c in smplist if c not in tab.columns]
    if missing_cols:
        raise ValueError(f"The following samples from meta are missing in tab columns: {missing_cols}")

    # --- Prepare q values and result DataFrame --------------------------------
    # Half a step past the end keeps q_end despite floating-point accumulation.
    xvalues = np.arange(q_start, q_end + (q_step / 2.0), q_step)
    df = pd.DataFrame(index=xvalues, columns=smplist, dtype=float)
    tab_use = tab[smplist]

    # --- Compute alpha diversity per q ----------------------------------------
    for q in xvalues:
        if div_type == "naive":
            alphadiv = naive_alpha(tab_use, q=q, use_values_in_tab=use_values_in_tab)
        elif div_type == "phyl":
            alphadiv = phyl_alpha(obj, q=q, index="D", use_values_in_tab=use_values_in_tab)
        else:  # "func" (div_type validated above)
            alphadiv = func_alpha(tab_use, distmat=distmat, q=q, index="D", use_values_in_tab=use_values_in_tab)
        try:
            df.loc[q, smplist] = alphadiv
        except Exception as e:
            raise ValueError(
                f"Failed to assign alpha diversity values for q={q}. "
                f"Ensure the backend returns values aligned to samples. Error: {e}"
            ) from e

    # --- Plotting --------------------------------------------------------------
    plt.rcParams.update({"font.size": fontsize})
    fig, ax = plt.subplots(figsize=figsize)

    if colorlist is None:
        try:
            colorlist = get_colors_markers("colors")
        except Exception:
            colorlist = None
        # An empty palette would make the modulo below divide by zero.
        if colorlist is None or len(colorlist) == 0:
            prop_cycle = plt.rcParams.get("axes.prop_cycle", None)
            colorlist = prop_cycle.by_key()["color"] if prop_cycle is not None else ["C0", "C1", "C2", "C3", "C4"]

    seen_groups: List[Any] = []
    for s in df.columns:
        group = meta.loc[s, color_by] if color_by is not None else s
        # Only the first sample of a group gets a legend entry.
        label = group if group not in seen_groups else "_nolegend_"
        if group not in seen_groups:
            seen_groups.append(group)
        col = colorlist[seen_groups.index(group) % len(colorlist)]
        if ylog:
            ax.semilogy(df.index, df[s].values, lw=1, color=col, label=label)
        else:
            ax.plot(df.index, df[s].values, lw=1, color=col, label=label)

    ylabels = {
        "naive": r"Diversity ($^{q}$D)",
        "phyl": r"Phyl. diversity ($^{q}$D)",
        "func": r"Func. diversity ($^{q}$D)",
    }
    ax.set_ylabel(ylabels[div_type])
    ax.set_xlabel("Diversity order (q)")

    ax.set_xticks(np.arange(q_start, q_end + 1e-12, 0.5))
    ax.set_xlim(q_start, q_end)
    ax.legend(bbox_to_anchor=(1, 1), loc="upper left", frameon=False)
    fig.tight_layout()

    if savename:
        fig.savefig(savename, dpi=240)
        try:
            fig.savefig(f"{savename}.pdf", format="pdf")
        except Exception as err:  # best effort: the PDF copy is optional
            warnings.warn(
                f"Could not save PDF copy '{savename}.pdf': {err}",
                RuntimeWarning,
                stacklevel=2,
            )
    return fig, ax, df
# ----------------------------------------------------------------------------- # Plot beta diversity profiles # -----------------------------------------------------------------------------
def beta_diversity_profile(
    obj: Union[Dict[str, Any], Any],
    *,
    q_range: Tuple[float, float] = (0.0, 2.0),
    q_step: float = 0.05,
    distmat: Optional[pd.DataFrame] = None,
    group_by: Optional[str] = None,
    order: Optional[str] = None,
    dis: bool = True,
    viewpoint: str = "regional",  # {'local', 'regional'} only used when dis=True
    div_type: str = "naive",
    ylog: bool = False,
    figsize: Tuple[float, float] = (18 / 2.54, 14 / 2.54),
    fontsize: int = 10,
    colorlist: Optional[List[str]] = None,
    savename: Optional[str] = None,
    drop_na_groups: bool = True,
) -> Tuple["plt.Figure", "plt.Axes", pd.DataFrame]:
    """
    Plot multi-sample β-diversity (or its dissimilarity transform) vs diversity order q.

    β_q is evaluated on an inclusive q grid with ``naive_multi_beta``,
    ``phyl_multi_beta`` or ``func_multi_beta`` (by `div_type`), and one curve
    is drawn per group (or a single 'all' curve when ``group_by=None``).

    Parameters
    ----------
    obj : dict or MicrobiomeData-like
        Must provide ``tab`` (features x samples) and ``meta`` (sample
        metadata) via ``get_df``; ``tree`` as well when ``div_type='phyl'``.
    q_range : (float, float), default=(0.0, 2.0)
        Inclusive (start, end) range of q values.
    q_step : float, default=0.05
        Step between q values. Must be positive.
    distmat : pandas.DataFrame, optional
        Functional distance matrix (features x features). Required when
        ``div_type='func'``.
    group_by : str or None, default=None
        Metadata column defining groups. If None, all samples form one group.
    order : str or None, default=None
        Metadata column used to sort samples; groups are plotted in order of
        first appearance in the (optionally) sorted metadata.
    dis : bool, default=True
        If True, plot dissimilarity ('local_dis' or 'regional_dis' column of
        the backend result, per `viewpoint`) instead of raw β ('beta').
    viewpoint : {'local', 'regional'}, default='regional'
        Which dissimilarity column to use when ``dis=True``.
    div_type : {'naive', 'phyl', 'func'}, default='naive'
        'naive' (taxonomic), 'phyl' (phylogenetic, requires ``tree``) or
        'func' (functional, requires `distmat`) β-diversity. Case-insensitive.
    ylog : bool, default=False
        If True, use a logarithmic y-scale (non-positive values are not shown).
    figsize : (float, float), default=(18/2.54, 14/2.54)
        Figure size in inches.
    fontsize : int, default=10
        Base font size (set globally via ``plt.rcParams``).
    colorlist : list of str or None, default=None
        Colors for groups. If None, uses Matplotlib's default color cycle.
    savename : str or None, default=None
        If provided, save the figure to this path and also to
        ``savename + '.pdf'``. A failure to write the PDF copy is reported
        as a ``RuntimeWarning`` rather than raised.
    drop_na_groups : bool, default=True
        If True, drop groups that are NaN at every q (e.g. groups with <2 samples).

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
    df : pandas.DataFrame
        Rows = q values (rounded to 10 decimals), columns = group labels (as
        strings when `group_by` is given). Contains β (``dis=False``) or
        dissimilarity (``dis=True``).

    Raises
    ------
    ValueError
        For missing/invalid inputs, including an unknown `div_type` or a
        missing tree/distmat for 'phyl'/'func'.
    KeyError
        If `order` or `group_by` is not a metadata column.
    """
    # --- Validate input -------------------------------------------------------
    tab = get_df(obj, "tab")
    meta = get_df(obj, "meta")
    if not isinstance(tab, pd.DataFrame):
        raise ValueError("`obj` must provide a 'tab' DataFrame via get_df(obj, 'tab').")
    if not isinstance(meta, pd.DataFrame):
        raise ValueError("`obj` must provide a 'meta' DataFrame via get_df(obj, 'meta').")
    if tab.shape[1] < 2:
        raise ValueError("At least two samples are required to compute multi-sample β-diversity.")

    if not isinstance(q_range, (tuple, list)) or len(q_range) != 2:
        raise ValueError("`q_range` must be a (start, end) tuple of floats.")
    q_start, q_end = float(q_range[0]), float(q_range[1])
    if q_start > q_end:
        raise ValueError("`q_range` must satisfy start <= end.")
    if q_step <= 0:
        raise ValueError("`q_step` must be positive.")

    # Validate div_type up front instead of failing inside the q loop.
    div_type = str(div_type).lower()
    if div_type not in {"naive", "phyl", "func"}:
        raise ValueError("`div_type` must be one of {'naive', 'phyl', 'func'}.")
    if div_type == "phyl" and get_df(obj, "tree") is None:
        raise ValueError("`div_type='phyl'` requires `obj` to provide a 'tree' via get_df(obj, 'tree').")
    if div_type == "func" and not isinstance(distmat, pd.DataFrame):
        raise ValueError("`div_type='func'` requires `distmat` as a pandas.DataFrame.")

    viewpoint = str(viewpoint).lower()
    if dis:
        if viewpoint not in {"local", "regional"}:
            raise ValueError("`viewpoint` must be 'local' or 'regional' when `dis=True`.")
        ycol = "local_dis" if viewpoint == "local" else "regional_dis"
        y_label = f"Dissimilarity ({viewpoint})"
    else:
        ycol = "beta"
        y_label = r"Multi-sample $\beta_q$"

    # Determine plotting order via metadata
    if order is not None:
        if order not in meta.columns:
            raise KeyError(f"`order` column '{order}' not found in metadata.")
        meta_sorted = meta.sort_values(by=order)
    else:
        meta_sorted = meta

    if group_by is None:
        groups_order = ["all"]
    else:
        if group_by not in meta.columns:
            raise KeyError(f"`group_by` column '{group_by}' not found in metadata.")
        # Preserve order of first appearance; labels are compared as strings.
        groups_order = list(pd.unique(meta_sorted[group_by].astype(str)))

    q_values = np.arange(q_start, q_end + (q_step / 2.0), q_step)  # inclusive of end
    df = pd.DataFrame(index=np.round(q_values, 10), columns=groups_order, dtype=float)

    # --- Compute β (or dissimilarity) curves ----------------------------------
    for q in df.index:
        if div_type == "naive":
            betadf = naive_multi_beta(obj, by=group_by, q=float(q))
        elif div_type == "phyl":
            betadf = phyl_multi_beta(obj, by=group_by, q=float(q))
        else:  # "func" (div_type validated above)
            betadf = func_multi_beta(obj, distmat, by=group_by, q=float(q))
        # groups_order holds string labels, so match on the string form of the
        # backend's group index (numeric group values would otherwise all be NaN).
        y = pd.Series(betadf[ycol].to_numpy(), index=betadf.index.astype(str))
        df.loc[q, groups_order] = y.reindex(groups_order).to_numpy()

    # Drop all-NaN groups if requested (e.g., groups with <2 samples throughout)
    if drop_na_groups:
        all_nan_cols = df.columns[df.isna().all()]
        if len(all_nan_cols) > 0:
            df = df.drop(columns=all_nan_cols)
        groups_order = [g for g in groups_order if g in df.columns]

    # --- Plotting --------------------------------------------------------------
    plt.rcParams.update({"font.size": fontsize})
    fig, ax = plt.subplots(figsize=figsize)

    if colorlist is None:
        prop_cycle = plt.rcParams.get("axes.prop_cycle", None)
        colorlist = prop_cycle.by_key()["color"] if prop_cycle is not None else ["C0", "C1", "C2", "C3", "C4"]

    plot_fn = ax.semilogy if ylog else ax.plot
    for i, grp in enumerate(groups_order):
        if grp not in df.columns:
            continue
        plot_fn(df.index.values, df[grp].values, lw=1.5, color=colorlist[i % len(colorlist)], label=str(grp))

    ax.set_xlabel("Diversity order (q)")
    ax.set_ylabel(y_label)
    ax.set_xticks(np.arange(q_start, q_end + 1e-12, 0.5))
    ax.set_xlim(q_start, q_end)
    ax.legend(bbox_to_anchor=(1, 1), loc="upper left", frameon=False, title=(group_by if group_by else "Group"))
    fig.tight_layout()

    if savename:
        fig.savefig(savename, dpi=240)
        try:
            fig.savefig(f"{savename}.pdf", format="pdf")
        except Exception as err:  # best effort: the PDF copy is optional
            warnings.warn(
                f"Could not save PDF copy '{savename}.pdf': {err}",
                RuntimeWarning,
                stacklevel=2,
            )
    return fig, ax, df