# Source code for qdiv.utils.phylo_utils

"""
Functions for phylogenetic trees:
    - parse_newick : returns tree dict
    - tree_to_dataframe : converts tree dict to dataframe
    - dataframe_to_tree : converts dataframe to tree dict
    - subset_tree : subsets tree dict to list of leaf nodes
    - subset_tree_df : subsets dataframe quickly, useful for diversity calculations
    - tree_to_newick : converts tree dict to newick
    - reroot_midpoint : roots a tree dict at midpoint
    - parse_leaves : get a set of leaves descending from a node
    - rename_leaves : rename leaves in dataframe
    - ra_to_branches : get dataframe with each branch and the ra for each sample
    - compute_Tmean : get Tmean for a tree dataframe and featurelist
"""
import pandas as pd
import numpy as np
import ast

# Public API: the names exported by a star-import of this module.
# NOTE(review): collapse_single_child_nodes and ladderize_tree_df are defined
# in this module but are not listed here — confirm whether that is intentional.
__all__ = [
    "parse_newick",
    "tree_to_dataframe",
    "dataframe_to_tree",
    "subset_tree",
    "subset_tree_df",
    "tree_to_newick",
    "reroot_midpoint",
    "parse_leaves",
    "rename_leaves",
    "ra_to_branches",
    "compute_Tmean"
]

AUTO_INTERNAL_PREFIX = "in"


def parse_newick(newick: str):
    """
    Parse a Newick string into a nested tree dictionary.

    Every node is a dict with the keys ``"name"``, ``"length"``,
    ``"children"`` (list of child node dicts) and ``"parent"`` (the parent
    node dict, or None for the root).

    Parsing rules
    -------------
    - Internal labels that look numeric are treated as bootstrap/support
      values and discarded.
    - Internal nodes without a usable label are named
      ``AUTO_INTERNAL_PREFIX`` + counter, numbered in the order in which
      their closing parenthesis is read.
    - The root's branch length is always set to None.
    - Parsing stops at the first ';'.
    - A bare single-leaf tree such as ``"A;"`` is accepted.
    - Quoted labels and bracketed comments get no special treatment.

    Parameters
    ----------
    newick : str
        The Newick text.

    Returns
    -------
    dict
        The root node.

    Raises
    ------
    ValueError
        If the string is empty, has unbalanced parentheses, has tokens
        outside the (single) top-level tree, or has an unreadable branch
        length.
    """
    s = newick
    n = len(s)
    delimiters = '():,;'

    def skip_ws(j):
        while j < n and s[j].isspace():
            j += 1
        return j

    def read_label(j):
        """Read a label up to the next structural character; trailing blanks are dropped."""
        start = j
        while j < n and s[j] not in delimiters:
            j += 1
        return s[start:j].rstrip(), j

    def read_length(j):
        """Read the float that follows ':'."""
        start = j
        while j < n and (s[j].isdigit() or s[j] in '+-.eE'):
            j += 1
        token = s[start:j]
        try:
            return float(token), j
        except ValueError:
            raise ValueError(
                f"Invalid branch length {token!r} at position {start}"
            ) from None

    def looks_numeric(tok):
        """Cheap test used to recognise bootstrap values on internal nodes."""
        if not tok:
            return False
        first = tok[0]
        if first in '+-.' or first.isdigit():
            try:
                float(tok)
            except ValueError:
                return False
            return True
        return False

    auto_counter = 1
    stack = []   # internal nodes whose ')' has not been read yet
    root = None

    i = skip_ws(0)
    while i < n:
        ch = s[i]

        if ch == '(':
            if not stack and root is not None:
                # A second top-level subtree would otherwise be silently dropped.
                raise ValueError(f"Unexpected '(' at position {i}: tree already closed")
            node = {"name": None, "length": None, "children": [], "parent": None}
            if stack:
                parent = stack[-1]
                node["parent"] = parent
                parent["children"].append(node)
            else:
                root = node
            stack.append(node)
            i = skip_ws(i + 1)

        elif ch == ',':
            i = skip_ws(i + 1)

        elif ch == ')':
            if not stack:
                raise ValueError(f"Unbalanced ')' at position {i}")
            i = skip_ws(i + 1)

            # Optional label; empty or numeric (bootstrap) labels are ignored.
            label = None
            if i < n and s[i] not in ':,);':
                label, i = read_label(i)
                i = skip_ws(i)
                if not label or looks_numeric(label):
                    label = None

            # Optional branch length
            length = None
            if i < n and s[i] == ':':
                i = skip_ws(i + 1)
                length, i = read_length(i)
                i = skip_ws(i)

            node = stack.pop()
            node["length"] = length
            if label is None:
                node["name"] = f"{AUTO_INTERNAL_PREFIX}{auto_counter}"
                auto_counter += 1
            else:
                node["name"] = label

        elif ch == ';':
            break

        else:
            # Leaf: [name][:length]
            if not stack:
                if ch.isspace():
                    i += 1
                    continue
                if root is not None:
                    # Text after the tree was closed (or a second bare leaf).
                    raise ValueError("Unexpected token while not inside a subtree")
                # Otherwise: a bare single-leaf tree, handled below.

            name, i = read_label(i)
            i = skip_ws(i)

            length = None
            if i < n and s[i] == ':':
                i = skip_ws(i + 1)
                length, i = read_length(i)
                i = skip_ws(i)

            leaf = {"name": name, "length": length, "children": [], "parent": None}
            if stack:
                parent = stack[-1]
                leaf["parent"] = parent
                parent["children"].append(leaf)
            else:
                root = leaf

    if root is None:
        raise ValueError("Empty or invalid Newick string")
    if stack:
        raise ValueError("Unbalanced parentheses: missing ')'")

    # The root has no parent edge.
    root["length"] = None
    return root

def tree_to_dataframe(tree):
    """
    Flatten a tree dict (as returned by parse_newick) into a DataFrame.

    Rows are in pre-order (a node is listed before its descendants, children
    in their stored order). Columns:

    - ``nodes``        : node name
    - ``parent``       : name of the parent node (None for the root)
    - ``branchL``      : length of the edge to the parent; 0.0 for the root
                         and for nodes whose length is None
    - ``leaves``       : set of names of the leaves below the node
                         (a leaf holds just its own name)
    - ``dist_to_root`` : sum of ``branchL`` from the root down to the node

    The walk is iterative and uses the tree structure itself, so it runs in
    a single pass and does not depend on Python's recursion limit.
    """
    names = []
    parent_names = []
    parent_rows = []      # row index of each node's parent, -1 for the root
    branch_lengths = []
    dists = []
    leaf_sets = []

    pending = [(tree, -1)]
    while pending:
        node, parent_row = pending.pop()
        row = len(names)
        children = node["children"]

        if parent_row < 0:
            parent_name = None
            branch = 0.0
            dist = 0.0
        else:
            parent_name = names[parent_row]
            branch = node["length"] or 0.0
            dist = dists[parent_row] + float(branch)

        names.append(node["name"])
        parent_names.append(parent_name)
        parent_rows.append(parent_row)
        branch_lengths.append(branch)
        dists.append(dist)
        # Leaves start with their own name; internal nodes are filled below.
        leaf_sets.append(set() if children else {node["name"]})

        # Reversed push so that the first child is popped (and listed) first.
        for child in reversed(children):
            pending.append((child, row))

    # In pre-order every descendant has a larger row index than its ancestor,
    # so a single backwards sweep pushes complete leaf sets up to the parents.
    for row in range(len(names) - 1, 0, -1):
        leaf_sets[parent_rows[row]] |= leaf_sets[row]

    return pd.DataFrame({
        "nodes": names,
        "parent": parent_names,
        "branchL": branch_lengths,
        "leaves": leaf_sets,
        "dist_to_root": dists,
    })


def dataframe_to_tree(df):
    """
    Rebuild the nested tree dict from a tree DataFrame.

    The frame needs the columns ``nodes``, ``parent`` and ``branchL``.
    Node names are stringified and stripped; parent values that are missing,
    empty, or spell "nan"/"none" (any case) count as "no parent".

    When the frame does not contain exactly one parent-less node, a synthetic
    root named ``"__synthetic_root__"`` (suffixed with ``"_x"`` until unique)
    is appended and every parent-less node is attached beneath it.

    Returns
    -------
    dict
        The root node: ``{"name", "length", "children", "parent"}``.
        A node's ``length`` is its ``branchL`` value, or None when the node
        has no parent. Rows whose parent name matches no node are left
        unattached.
    """
    frame = df.copy()

    def _clean_parent(value):
        if value is None or pd.isna(value):
            return None
        text = str(value).strip()
        if text == "" or text.lower() in {"nan", "none"}:
            return None
        return text

    # Normalise names first, then remember them for the synthetic-root check.
    frame["nodes"] = frame["nodes"].apply(
        lambda value: None if pd.isna(value) else str(value).strip()
    )
    known_names = set(frame["nodes"])
    frame["parent"] = frame["parent"].apply(_clean_parent)

    top_level = frame.loc[frame["parent"].isna(), "nodes"].tolist()

    if len(top_level) != 1:
        root_name = "__synthetic_root__"
        while root_name in known_names:
            root_name += "_x"

        extra_row = pd.DataFrame({
            "nodes": [root_name],
            "parent": [None],
            "branchL": [None],
            "leaves": [set()],
            "dist_to_root": [0.0],
        })
        frame = pd.concat([frame, extra_row], ignore_index=True)

        # Hang the previously parent-less nodes under the new root.
        for orphan in top_level:
            frame.loc[frame["nodes"] == orphan, "parent"] = root_name

        top_level = [root_name]

    names = frame["nodes"].tolist()
    parents = frame["parent"].tolist()
    lengths = frame["branchL"].tolist()

    # Pass 1: one dict per row (a later row with the same name wins).
    registry = {}
    for name, parent, branch in zip(names, parents, lengths):
        registry[name] = {
            "name": name,
            "length": branch if parent is not None else None,
            "children": [],
            "parent": None,
        }

    # Pass 2: wire up parent/child links where the parent is known.
    for name, parent in zip(names, parents):
        if parent is None or parent not in registry:
            continue
        registry[parent]["children"].append(registry[name])
        registry[name]["parent"] = registry[parent]

    return registry[top_level[0]]

# Subset tree works on dictionary
def subset_tree(tree, keep_leaves):
    """
    Return a pruned copy of a tree dict that keeps only the requested leaves.

    - Leaves whose name is not in ``keep_leaves`` are dropped.
    - Internal nodes left without any child are dropped too.
    - Names, branch lengths and child order of surviving nodes are kept;
      single-child internal nodes are not collapsed.
    - If nothing survives, a childless node carrying the root's name
      (length None) is returned.

    The result is built entirely from new node dicts with their own
    ``"parent"`` links; the input tree is never modified. The walk is
    iterative, so deep trees do not hit the recursion limit.

    Parameters
    ----------
    tree : dict
        Root node of the tree.
    keep_leaves : iterable of str
        Leaf names to keep.
    """
    keep = set(keep_leaves)

    def copy_of(node, children):
        return {
            "name": node["name"],
            "length": node["length"],
            "children": children,
            "parent": None,
        }

    # id(original node) -> pruned copy, or None when the node is removed.
    pruned_by_id = {}

    # Post-order walk: an internal node is revisited after all its children.
    pending = [(tree, False)]
    while pending:
        node, children_done = pending.pop()
        kids = node["children"]

        if not kids:
            pruned_by_id[id(node)] = copy_of(node, []) if node["name"] in keep else None
        elif not children_done:
            pending.append((node, True))
            pending.extend((child, False) for child in kids)
        else:
            survivors = [
                kept for kept in (pruned_by_id[id(child)] for child in kids)
                if kept is not None
            ]
            if survivors:
                clone = copy_of(node, survivors)
                for kept in survivors:
                    kept["parent"] = clone
                pruned_by_id[id(node)] = clone
            else:
                pruned_by_id[id(node)] = None

    pruned = pruned_by_id[id(tree)]
    if pruned is None:
        # Nothing kept: return a root-only, empty tree.
        return {
            "name": tree["name"],
            "length": None,
            "children": [],
            "parent": None,
        }
    return pruned

#Tree to newick works on tree dictionary
def tree_to_newick(node, *, precision=6):
    """
    Serialise a tree dict to a Newick string.

    Parameters
    ----------
    node : dict
        Root of the (sub)tree, with the keys ``"name"``, ``"length"`` and
        ``"children"``.
    precision : int
        Number of digits after the decimal point for branch lengths.

    Returns
    -------
    str
        Newick text ending with ';'. A node whose name is None gets an empty
        label (leaves included) and a node whose length is None gets no
        ``:length`` part.

    Notes
    -----
    The traversal is iterative, so deep trees do not hit the recursion limit.
    """

    def suffix(n):
        """Label plus optional ':length' for one node."""
        name = n["name"]
        length = n["length"]
        label = "" if name is None else f"{name}"
        if length is None:
            return label
        return f"{label}:{length:.{precision}f}"

    pieces = []
    # The work stack holds node dicts still to expand and ready-made strings.
    pending = [node]
    while pending:
        item = pending.pop()
        if isinstance(item, str):
            pieces.append(item)
            continue

        kids = item["children"]
        if not kids:
            pieces.append(suffix(item))
            continue

        pieces.append("(")
        # Pushed in reverse so pops yield: child0 , child1 , ... )label:len
        pending.append(")" + suffix(item))
        for idx in range(len(kids) - 1, -1, -1):
            pending.append(kids[idx])
            if idx > 0:
                pending.append(",")

    return "".join(pieces) + ";"

def collapse_single_child_nodes(node):
    """
    Remove internal nodes that have exactly one child, in place.

    The subtree is processed bottom-up. A node with a single child and a
    parent is replaced, in its parent's child list, by that child; the
    child's length becomes the sum of both lengths (None counts as 0, so the
    result is always a float). A node without a parent is never removed,
    even if it has a single child.

    Returns the node that now stands where ``node`` stood: ``node`` itself,
    or the descendant that replaced it.
    """
    # Bottom-up: settle every child subtree before looking at this node.
    kids = node["children"]
    for pos in range(len(kids)):
        kids[pos] = collapse_single_child_nodes(kids[pos])

    current = node
    while True:
        if len(current["children"]) != 1 or current["parent"] is None:
            return current

        above = current["parent"]
        only = current["children"][0]

        # Merge the two edges into one.
        merged = 0.0
        for piece in (current["length"], only["length"]):
            if piece is not None:
                merged += piece

        only["length"] = merged
        only["parent"] = above

        # Put the child into the slot the removed node occupied.
        for slot, sibling in enumerate(above["children"]):
            if sibling is current:
                above["children"][slot] = only
                break

        current = only

def _reroot_at_node(root, target_name: str):
    """
    Make the node called ``target_name`` the root of the tree, in place.

    Nodes are dicts with ``"name"``, ``"length"`` (distance to the parent),
    ``"children"`` and ``"parent"``. The target is searched depth-first
    through ``children`` starting at ``root``; the first node found with a
    matching name is used.

    Every edge on the path from the target up to the old root is reversed:
    the former parent becomes a child (appended last) and takes over that
    edge's length. Edges off the path are untouched. The new root ends up
    with ``parent`` and ``length`` set to None.

    Returns
    -------
    dict
        The new root (the target node). If the target already has no parent
        it is returned unchanged.

    Raises
    ------
    ValueError
        If no node has that name, if following ``parent`` links loops, or if
        the result does not have exactly one parent-less node.
    """
    # --- locate the target ---------------------------------------------------
    target = None
    todo, visited = [root], set()
    while todo:
        candidate = todo.pop()
        key = id(candidate)
        if key in visited:
            continue
        visited.add(key)
        if candidate["name"] == target_name:
            target = candidate
            break
        todo.extend(candidate["children"])

    if target is None:
        raise ValueError(f"Target node '{target_name}' not found in tree.")
    if target["parent"] is None:
        return target

    # --- chain of ancestors: [target, ..., old root] ------------------------
    chain = []
    walked = set()
    cursor = target
    while cursor is not None:
        if id(cursor) in walked:
            raise ValueError("Cycle detected while walking to root.")
        walked.add(id(cursor))
        chain.append(cursor)
        cursor = cursor["parent"]

    # Length of the edge between chain[k] and chain[k + 1], read before any
    # node is modified.
    old_lengths = [link["length"] for link in chain[:-1]]

    # --- reverse every edge on the chain ------------------------------------
    for k, (lower, upper) in enumerate(zip(chain, chain[1:])):
        upper["children"] = [c for c in upper["children"] if c is not lower]
        lower["children"].append(upper)
        upper["parent"] = lower
        upper["length"] = old_lengths[k]

    target["parent"] = None
    target["length"] = None

    # --- sanity check: exactly one parent-less node below the new root ------
    parentless = 0
    todo, visited = [target], set()
    while todo:
        candidate = todo.pop()
        key = id(candidate)
        if key in visited:
            continue
        visited.add(key)
        if candidate["parent"] is None:
            parentless += 1
        todo.extend(candidate["children"])

    if parentless != 1:
        raise ValueError("Re-rooting invariant violated: the tree has multiple roots.")

    return target


def reroot_midpoint(root, *, name_hint="in_midroot", tol=1e-12):
    """
    Re-root a tree dict at the midpoint of its longest leaf-to-leaf path.

    Nodes are dicts with ``"name"``, ``"length"`` (distance to the parent,
    None counted as 0), ``"children"`` and ``"parent"``.

    Procedure
    ---------
    1. From an arbitrary leaf find the farthest leaf ``L1``; from ``L1`` find
       the farthest leaf ``L2``. Their distance ``D`` is the longest path.
    2. Walk from ``L1`` towards ``L2`` up to ``D / 2``.
    3. If that point is within ``tol`` of a node, re-root at that node;
       otherwise split the edge with a new node (named ``name_hint``, with a
       number appended if the name is taken) and re-root there.

    Re-rooting itself is delegated to ``_reroot_at_node``, which looks the
    node up by name starting from ``root``. The tree is modified in place.

    Returns
    -------
    dict
        The new root. A single-node tree is returned with ``length`` and
        ``parent`` reset to None; a tree with ``D <= tol`` is returned
        unchanged.
    """

    # ----------------- helpers -----------------
    def every_node(start):
        """All nodes connected to ``start`` (parent pushed before children)."""
        found, visited, todo = [], set(), [start]
        while todo:
            cur = todo.pop()
            key = id(cur)
            if key in visited:
                continue
            visited.add(key)
            found.append(cur)
            if cur["parent"] is not None:
                todo.append(cur["parent"])
            todo.extend(cur["children"])
        return found

    def adjacent(n):
        """Children first, then the parent (this fixes the visiting order)."""
        near = list(n["children"])
        if n["parent"] is not None:
            near.append(n["parent"])
        return near

    def span(a, b):
        """Length of the edge joining two adjacent nodes."""
        if b is a["parent"]:
            holder = a
        elif a is b["parent"]:
            holder = b
        else:
            raise ValueError("Nodes are not adjacent in the tree")
        return 0.0 if holder["length"] is None else float(holder["length"])

    def sweep(origin):
        """
        Distances from ``origin`` to every node.

        Returns (farthest leaf, id->distance, id->previous id, id->node).
        Among equally distant leaves the one discovered first wins.
        """
        origin_key = id(origin)
        reach = {origin_key: 0.0}
        came_from = {origin_key: None}
        node_of = {origin_key: origin}

        todo, done = [origin], set()
        while todo:
            here = todo.pop()
            here_key = id(here)
            if here_key in done:
                continue
            done.add(here_key)
            for nxt in adjacent(here):
                nxt_key = id(nxt)
                if nxt_key in reach:
                    continue
                step = span(here, nxt)
                reach[nxt_key] = reach[here_key] + step
                came_from[nxt_key] = here_key
                node_of[nxt_key] = nxt
                todo.append(nxt)

        best, best_d = None, -1.0
        for cand in node_of.values():
            if len(cand["children"]) == 0:
                d = reach[id(cand)]
                if d > best_d:
                    best_d = d
                    best = cand
        return best, reach, came_from, node_of

    def fresh_name(base):
        """A node name not yet used anywhere in the tree."""
        taken = {n["name"] for n in every_node(root)}
        if base not in taken:
            return base
        k = 1
        while f"{base}{k}" in taken:
            k += 1
        return f"{base}{k}"

    def split_edge(u, v, r_from_u, name):
        """Insert and return a new node at distance ``r_from_u`` from ``u`` on edge (u, v)."""
        if v is u["parent"]:
            lower, upper, where = u, v, "(u,parent)"
        elif u is v["parent"]:
            lower, upper, where = v, u, "(parent,v)"
        else:
            raise ValueError("u and v are not adjacent when inserting a split node.")

        w = 0.0 if lower["length"] is None else float(lower["length"])
        r = float(r_from_u)
        if not (0.0 < r < w):
            raise ValueError(f"Split distance must be within the edge {where}")

        # The new node sits between lower (child side) and upper (parent side).
        if lower is u:
            lower_len, mid_len = r, w - r
        else:
            lower_len, mid_len = w - r, r

        middle = {"name": name, "length": mid_len, "children": [lower], "parent": upper}
        for slot, ch in enumerate(upper["children"]):
            if ch is lower:
                upper["children"][slot] = middle
                break
        lower["parent"] = middle
        lower["length"] = lower_len
        return middle

    # ----------------- main logic -----------------
    everything = every_node(root)
    if len(everything) == 1:
        # A lone node is trivially rooted.
        root["length"] = None
        root["parent"] = None
        return root

    seed = next((n for n in everything if len(n["children"]) == 0), None)
    if seed is None:
        seed = everything[0]

    # Two sweeps give the two ends of the longest path.
    end_a, _, _, _ = sweep(seed)
    end_b, dist_a, prev_a, nodes_a = sweep(end_a)
    diameter = dist_a[id(end_b)]
    if diameter <= tol:
        # Nothing to balance.
        return root

    # Rebuild the path end_a -> ... -> end_b from the predecessor map.
    trail = []
    key = id(end_b)
    while key is not None:
        trail.append(key)
        if key == id(end_a):
            break
        key = prev_a.get(key, None)
    if not trail or trail[-1] != id(end_a):
        raise ValueError("Failed to reconstruct path between nodes.")
    trail.reverse()
    route = [nodes_a[k] for k in trail]

    half = diameter / 2.0
    walked = 0.0
    for a, b in zip(route, route[1:]):
        w = span(a, b)
        if walked + w + tol < half:
            walked += w
            continue

        offset = half - walked   # distance of the midpoint from node a
        if offset <= tol:
            return _reroot_at_node(root, a["name"])
        if (w - offset) <= tol:
            return _reroot_at_node(root, b["name"])

        # The midpoint falls strictly inside the edge: split it there.
        middle = split_edge(a, b, offset, fresh_name(name_hint))
        return _reroot_at_node(root, middle["name"])

    # Fallback (should not occur): reroot at the first end of the path.
    return _reroot_at_node(root, end_a["name"])

#Get the set of leaves from a dataframe node
def parse_leaves(leaves):
    """
    Turn one 'leaves' cell of a tree DataFrame into a set of leaf names.

    Accepted inputs
    ---------------
    - str : first read as a Python literal (e.g. ``"{'A', 'B'}"``,
      ``"['A', 'B']"``). A list/tuple/set literal yields its items; any other
      literal is used as a single item via ``str()``. If the text is not a
      valid literal it is split on commas after removing one enclosing
      ``{}`` or ``()`` pair, all square brackets and all quote characters
      (e.g. ``"A, B"`` or ``"{A, B}"``).
    - list, tuple, set, frozenset, numpy array : items are used as they are.
    - missing value (None / NaN) : empty set.
    - anything else : treated as a single item.

    Returns
    -------
    set of str
        Only items that are strings and not blank are kept.
    """
    if isinstance(leaves, str):
        text = leaves.strip()
        if text == "set()":
            # repr of an empty set; not every interpreter's literal_eval accepts it
            return set()
        try:
            parsed = ast.literal_eval(leaves)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a Python literal: fall back to a CSV-like split.
            if len(text) >= 2 and text[0] + text[-1] in ("{}", "()"):
                text = text[1:-1]
            cleaned = (
                text.replace("[", "")
                .replace("]", "")
                .replace('"', "")
                .replace("'", "")
            )
            seq = [tok for tok in (part.strip() for part in cleaned.split(",")) if tok]
        else:
            if isinstance(parsed, (list, tuple, set)):
                seq = list(parsed)
            else:
                # A scalar literal: treat it as one name.
                seq = [str(parsed)]
    elif isinstance(leaves, np.ndarray):
        seq = leaves.tolist()
    elif isinstance(leaves, (list, tuple, set, frozenset)):
        seq = list(leaves)
    elif pd.isna(leaves):
        seq = []
    else:
        seq = [leaves]

    # Keep only non-empty strings
    return {x for x in seq if isinstance(x, str) and x.strip() != ""}

# Rename leaves work on dataframe
def rename_leaves(
    df: pd.DataFrame,
    leaf_dict: dict,
    *,
    allow_partial: bool = True,
    inplace: bool = False,
) -> pd.DataFrame:
    """
    Rename leaf labels in a DataFrame-based tree (see tree_to_dataframe()).

    What is renamed
    ---------------
    - 'nodes'  : only tips, i.e. names that occur in some 'leaves' set and
                 are never used as a parent.
    - 'parent' : only values that are themselves leaf names (rare). Mapping
                 keys that merely coincide with an internal node name are
                 ignored, so parent/child links stay intact.
    - 'leaves' : every set has its members mapped; sets without any mapped
                 member are kept as the same object.

    Expected columns: 'nodes', 'parent', 'branchL', 'leaves' (sets of str;
    lists/tuples are converted, a plain string becomes a one-item set, other
    values become an empty set).

    Parameters
    ----------
    df : pd.DataFrame
    leaf_dict : dict {old_leaf_name -> new_leaf_name}
    allow_partial : bool
        If False, raise if some mapping keys are not present among leaves.
    inplace : bool
        If True mutate df, else work on a copy.

    Returns
    -------
    pd.DataFrame
        The renamed frame (``df`` itself when ``inplace`` is True).

    Raises
    ------
    TypeError
        If ``df`` is not a DataFrame or ``leaf_dict`` is not a dict.
    ValueError
        If required columns are missing, if ``allow_partial`` is False and a
        key is not a leaf, if two old names map to the same new name, or if
        the renaming would produce duplicate tip names.
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError("df must be a pandas DataFrame")
    if not isinstance(leaf_dict, dict):
        raise TypeError("leaf_dict must be a dict {old->new}")

    required_cols = {"nodes", "parent", "branchL", "leaves"}
    missing = required_cols - set(df.columns)
    if missing:
        raise ValueError(f"Tree DataFrame missing columns: {sorted(missing)}")

    T = df if inplace else df.copy(deep=True)

    # --- 1) Make sure every 'leaves' cell is a set ---
    def _as_set(x):
        if isinstance(x, set):
            return x
        if isinstance(x, (list, tuple)):
            return set(x)
        if isinstance(x, str):
            return {x}
        return set()

    if not all(isinstance(x, set) for x in T["leaves"].values):
        T["leaves"] = T["leaves"].apply(_as_set)

    # --- 2) Union of all leaf names ---
    all_leaves = set()
    for s in T["leaves"].values:
        all_leaves |= s

    # --- 3) Validation: presence & collisions ---
    if not allow_partial:
        missing_keys = set(leaf_dict.keys()) - all_leaves
        if missing_keys:
            raise ValueError(f"Some mapping keys are not present among leaves: {sorted(missing_keys)}")

    # Two different old names must not map to the same new name.
    reverse = {}
    for old, new in leaf_dict.items():
        if new in reverse and reverse[new] != old:
            raise ValueError(f"Mapping would collide: '{old}' and '{reverse[new]}' -> '{new}'")
        reverse[new] = old

    # Tips are leaf names that never appear as a parent.
    parents = set(T["parent"].dropna().values)
    nodes_vals = T["nodes"].values
    is_tip_mask = [(n in all_leaves) and (n not in parents) for n in nodes_vals]

    # The renamed tips must still be unique.
    projected = set()
    for n, is_tip in zip(nodes_vals, is_tip_mask):
        if not is_tip:
            continue
        new_name = leaf_dict.get(n, n)
        if new_name in projected:
            raise ValueError(f"Renaming would create duplicate tip name '{new_name}'")
        projected.add(new_name)

    # --- 4) Apply renaming ---
    # 4a) 'nodes': tips only, internal nodes are never touched
    if any(is_tip_mask):
        new_nodes = nodes_vals.copy()
        for i, is_tip in enumerate(is_tip_mask):
            if is_tip:
                old = new_nodes[i]
                new_nodes[i] = leaf_dict.get(old, old)
        T["nodes"] = new_nodes

    # 4b) 'parent': remap only parents that are leaf names. Remapping on the
    # mapping key alone would rename references to an internal node whose own
    # 'nodes' entry is left unchanged, detaching its children.
    if not T["parent"].isna().all():
        new_parent = T["parent"].values.copy()
        for i, p in enumerate(new_parent):
            if p is not None and p in all_leaves and p in leaf_dict:
                new_parent[i] = leaf_dict[p]
        T["parent"] = new_parent

    # 4c) 'leaves': rebuild only the sets that contain a mapped name
    mapped_keys = set(leaf_dict)   # built once, not once per row
    new_leaves = []
    changed_any = False
    for s in T["leaves"].values:
        if s and not s.isdisjoint(mapped_keys):
            new_leaves.append({leaf_dict.get(x, x) for x in s})
            changed_any = True
        else:
            new_leaves.append(s)

    if changed_any:
        T["leaves"] = new_leaves

    return T

#Subset a tree dataframe for diversity functions
def subset_tree_df(tree_df: pd.DataFrame, keep_leaves) -> pd.DataFrame:
    """
    Restrict a tree DataFrame to the branches leading to ``keep_leaves``.

    Rows whose 'leaves' set shares no name with ``keep_leaves`` are dropped;
    the 'leaves' sets of the remaining rows are cut down to the kept names.
    Non-set 'leaves' values are converted with ``set(...)`` first. The input
    frame is not modified and the original index labels are retained.
    """
    wanted = set(keep_leaves)
    out = tree_df.copy()

    # tree_to_dataframe already stores sets, so this is normally skipped
    if any(not isinstance(cell, set) for cell in out["leaves"].values):
        out["leaves"] = out["leaves"].apply(
            lambda cell: cell if isinstance(cell, set) else set(cell)
        )

    hits = out["leaves"].apply(lambda cell: not cell.isdisjoint(wanted))
    out = out.loc[hits].copy()
    out["leaves"] = out["leaves"].apply(lambda cell: cell.intersection(wanted))
    return out

# Get ra for each sample and each branch
def ra_to_branches(ra: pd.DataFrame, tree_df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a (branches × samples) relative-abundance table.

    For every row (branch) of ``tree_df`` the abundance vectors of all leaves
    in that row's 'leaves' set are summed.

    Parameters
    ----------
    ra : pd.DataFrame
        Leaf × sample abundance table; the index holds the leaf names.
    tree_df : pd.DataFrame
        Tree dataframe with a 'leaves' column (sets of leaf names).

    Returns
    -------
    pd.DataFrame
        Indexed like ``tree_df`` with the columns of ``ra``. Leaves that are
        not in ``ra`` contribute nothing.
    """
    n_branches, n_samples = tree_df.shape[0], ra.shape[1]
    A = np.zeros((n_branches, n_samples), dtype=float)

    # Numpy view of the abundances, same row/sample order as `ra`
    ra_vals = ra.to_numpy(dtype=float, copy=False)
    leaf_pos = {leaf: i for i, leaf in enumerate(ra.index)}

    # For each leaf present in `ra`, the branch rows it contributes to
    rows_by_leaf = {}
    for row_idx, leaf_set in enumerate(tree_df["leaves"].values):
        for leaf in leaf_set:
            if leaf in leaf_pos:
                rows_by_leaf.setdefault(leaf, []).append(row_idx)

    # One vectorised add per leaf
    for leaf, rows in rows_by_leaf.items():
        A[np.asarray(rows, dtype=np.intp), :] += ra_vals[leaf_pos[leaf], :]

    return pd.DataFrame(A, index=tree_df.index, columns=ra.columns)
def compute_Tmean(
    tree_df: pd.DataFrame,
    abund: pd.DataFrame,
) -> pd.Series:
    """
    Compute sample-specific mean tree height T (Chao et al. 2010).

    T_j = sum_b L_b * a_{b,j}

    Parameters
    ----------
    tree_df : pd.DataFrame
        Tree dataframe with column 'branchL'
    abund : pd.DataFrame
        Branch × sample matrix of descendant relative abundances (a_{b,j});
        rows must be in the same order as the rows of ``tree_df``.

    Returns
    -------
    pd.Series
        One T value per sample (indexed by sample name), named "Tmean".

    Raises
    ------
    ValueError
        If the two frames do not have the same number of rows.
    """
    L = tree_df["branchL"].to_numpy(dtype=float)[:, None]   # (branches × 1)
    A = abund.to_numpy(dtype=float)                         # (branches × samples)

    # Checked explicitly: a one-row frame would otherwise broadcast silently.
    if L.shape[0] != A.shape[0]:
        raise ValueError(
            f"tree_df has {L.shape[0]} rows but abund has {A.shape[0]} rows"
        )

    T = (L * A).sum(axis=0)
    return pd.Series(T, index=abund.columns, name="Tmean")
def ladderize_tree_df(df, *, right=True):
    """
    Ladderize a DataFrame-based tree (as produced by tree_to_dataframe).

    Rows are re-ordered into a depth-first (pre-order) listing in which the
    children of every node are sorted by clade size (number of names in
    their 'leaves' set):

    right=True  → the child with more leaves is listed first
    right=False → the child with fewer leaves is listed first

    Children of equal size keep their original relative order.
    NOTE(review): how this row order maps to "upper/lower" in a plot depends
    on the plotting code — confirm against the caller.

    Returns a new DataFrame with a fresh 0..n-1 index; the input is not
    modified. Rows that cannot be reached from the root are placed last.

    Raises
    ------
    ValueError
        If the frame does not have exactly one root (a row with a missing
        parent), or if a node name is reached twice during the traversal
        (duplicate names or a cycle).
    """
    df = df.copy()

    # Ensure leaves are sets
    if not all(isinstance(x, set) for x in df["leaves"]):
        df["leaves"] = df["leaves"].apply(lambda x: set(x) if not isinstance(x, set) else x)

    # Build children map
    children = {}
    for n, p in zip(df["nodes"], df["parent"]):
        if pd.notna(p):
            children.setdefault(p, []).append(n)

    # Subtree size directly from leaves
    size = dict(zip(df["nodes"], df["leaves"].apply(len)))

    roots = df.loc[df["parent"].isna(), "nodes"].tolist()
    if len(roots) != 1:
        raise ValueError("Tree must have exactly one root")

    # Iterative pre-order walk (no recursion limit on deep trees)
    ordered = []
    seen = set()
    pending = [roots[0]]
    while pending:
        n = pending.pop()
        if n in seen:
            raise ValueError(f"Node name '{n}' reached twice: duplicate names or a cycle")
        seen.add(n)
        ordered.append(n)
        kids = children.get(n)
        if kids:
            kids_sorted = sorted(kids, key=size.__getitem__, reverse=right)
            # Reversed push so the first sorted child is visited first.
            pending.extend(reversed(kids_sorted))

    # Reorder the DataFrame rows in the ladderized DFS order
    df["__order"] = pd.Categorical(df["nodes"], ordered, ordered=True)
    df = df.sort_values("__order").drop(columns="__order").reset_index(drop=True)
    return df