# Source code for proteobench.plotting.plot_generator_denovo

"""
Module for plotting results of de novo models
"""

import ast
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from proteobench.plotting.plot_generator_base import PlotGeneratorBase

# Small offset used by the PTM precision computations in this module to keep
# count denominators away from zero.
EPSILON = 0.0001


def flatten_results_column(df):
    """
    Flatten the nested ``results`` dictionaries of a datapoint table.

    Every row of ``df`` is expected to carry a ``software_name`` value and a
    ``results`` mapping of the form
    ``results[level][evaluation][metric]`` with ``level`` in
    ``("peptide", "aa")``, ``evaluation`` in ``("mass", "exact")`` and
    ``metric`` in ``("precision", "recall", "coverage")``.

    Parameters
    ----------
    df : pd.DataFrame
        Table with one datapoint per row.

    Returns
    -------
    pd.DataFrame
        One row per input row with the columns ``engine`` and
        ``{level}_{evaluation}_{metric}`` (peptide before aa, mass before
        exact, precision/recall/coverage). The result carries the index of
        ``df`` so that it can be concatenated column-wise with ``df`` without
        misaligning rows when ``df`` does not have a default 0..n-1 index.
    """
    # Materialise the rows once; an empty frame (even one without columns)
    # simply yields empty lists below.
    rows = [row for _, row in df.iterrows()]

    flat = {"engine": [row["software_name"] for row in rows]}
    for level in ("peptide", "aa"):
        for evaluation in ("mass", "exact"):
            for metric in ("precision", "recall", "coverage"):
                flat[f"{level}_{evaluation}_{metric}"] = [
                    row["results"][level][evaluation][metric] for row in rows
                ]

    return pd.DataFrame(flat, index=df.index)
class DeNovoPlotGenerator(PlotGeneratorBase):
    """
    Plot generator for de novo sequencing data points.

    Implements the PlotGeneratorBase interface for consistent module plotting.
    """

    # Defaults for the main metric plot; ``generate_in_depth_plots`` forwards
    # the same colour mapping to the in-depth plots.
    _MAIN_SOFTWARE_COLORS = {
        "AdaNovo": "#88ccef",
        "Casanovo": "#cc6777",
        "DeepNovo": "#ddcc77",
        "PepNet": "#147733",
        "Pi-HelixNovo": "#342288",
        "Pi-PrimeNovo": "#aa4599",
        "PEAKS": "#671100",
    }
    _MAIN_SOFTWARE_MARKERS = {
        "AdaNovo": "circle",
        "Casanovo": "square",
        "DeepNovo": "diamond",
        "PepNet": "cross",
        "Pi-HelixNovo": "x",
        "Pi-PrimeNovo": "triangle-up",
        "PEAKS": "star",
    }
    # Defaults used when an in-depth plot method is called directly without a
    # colour mapping (these differ from the main-plot palette).
    _IN_DEPTH_SOFTWARE_COLORS = {
        "AdaNovo": "#8b26ff",
        "Casanovo": "#8bc6fd",
        "DeepNovo": "#108E2E",
        "PepNet": "#F89008",
        "Pi-HelixNovo": "#E43924",
        "Pi-PrimeNovo": "#663200",
        "PEAKS": "#f032e6",
    }
    _IN_DEPTH_SOFTWARE_COLORS_WITH_TEST = {**_IN_DEPTH_SOFTWARE_COLORS, "test": "black"}
    # Used for software tools that are missing from a colour/marker mapping.
    _FALLBACK_COLOR = "gray"
    _FALLBACK_MARKER = "circle"

    def plot_main_metric(self, result_df: pd.DataFrame, **kwargs) -> go.Figure:
        """
        Generate the main performance metric plot.

        Peptide-level scores are drawn on the x-axis and amino-acid-level
        scores on the y-axis, one trace per (colour, software tool) pair.

        Parameters
        ----------
        result_df : pd.DataFrame
            DataFrame containing the results to plot.
        **kwargs : dict
            Additional parameters:
            - level: str (default "precision") - metric type ("precision" or "recall")
            - evaluation_type: str (default "mass") - evaluation type ("mass" or "exact")
            - colorblind_mode: bool (default False) - whether to use different shapes for software tools
            - software_colors: Dict[str, str] - color mapping for software tools
            - software_markers: Dict[str, str] - marker mapping for software tools
              (used when colorblind_mode is True)
            - mapping: Dict[str, int] - size mapping for old/new datapoints
            - highlight_color: str - color for highlighted datapoints
            - label: str - label field to display ("None" disables text labels)

        Returns
        -------
        go.Figure
            The generated plotly figure for the main performance metric.
        """
        level = kwargs.get("level", "precision")
        evaluation_type = kwargs.get("evaluation_type", "mass")
        colorblind_mode = kwargs.get("colorblind_mode", False)
        software_colors = kwargs.get("software_colors", self._MAIN_SOFTWARE_COLORS)
        software_markers = kwargs.get("software_markers", self._MAIN_SOFTWARE_MARKERS)
        mapping = kwargs.get("mapping", {"old": 10, "new": 20})
        highlight_color = kwargs.get("highlight_color", "#d30067")
        label = kwargs.get("label", "None")

        x_column = f"peptide_{evaluation_type}_{level}"
        y_column = f"aa_{evaluation_type}_{level}"

        results_df = flatten_results_column(result_df)
        # The flattened frame has one row per input row, in the same order.
        # Share the input index so the column-wise concat pairs rows by
        # position even when ``result_df`` has a non-default index.
        results_df.index = result_df.index
        benchmark_metrics_df = pd.concat([result_df, results_df], axis=1)

        # Hover text with detailed information for each data point.
        hover_texts = [self._build_hover_text(row) for _, row in benchmark_metrics_df.iterrows()]

        scatter_size = [mapping[item] for item in benchmark_metrics_df["old_new"]]
        # Unknown tools fall back to a neutral colour/marker instead of raising.
        colors = [
            software_colors.get(software, self._FALLBACK_COLOR)
            for software in benchmark_metrics_df["software_name"]
        ]
        markers = [
            software_markers.get(software, self._FALLBACK_MARKER)
            for software in benchmark_metrics_df["software_name"]
        ]
        if "Highlight" in benchmark_metrics_df.columns:
            highlights = list(benchmark_metrics_df["Highlight"])
            # Highlighted points are drawn twice as large and in the highlight colour.
            scatter_size = [size * 2 if flag else size for size, flag in zip(scatter_size, highlights)]
            colors = [highlight_color if flag else color for color, flag in zip(colors, highlights)]

        benchmark_metrics_df["color"] = colors
        benchmark_metrics_df["hover_text"] = hover_texts
        benchmark_metrics_df["scatter_size"] = scatter_size
        benchmark_metrics_df["marker"] = markers if colorblind_mode else "circle"

        fig = go.Figure(
            layout_yaxis_range=self._padded_range(results_df[y_column]),
            layout_xaxis_range=self._padded_range(results_df[x_column]),
        )

        # One trace per unique colour/software combination; highlighted points
        # of a tool get their own trace because their colour differs.
        combinations = benchmark_metrics_df[["color", "software_name"]].drop_duplicates()
        for _, combination in combinations.iterrows():
            tmp_df = benchmark_metrics_df[
                (benchmark_metrics_df["color"] == combination["color"])
                & (benchmark_metrics_df["software_name"] == combination["software_name"])
            ]
            fig.add_trace(
                go.Scatter(
                    x=tmp_df[x_column],
                    y=tmp_df[y_column],
                    mode="markers" if label == "None" else "markers+text",
                    hovertext=tmp_df["hover_text"],
                    text=tmp_df[label] if label != "None" else None,
                    marker=dict(
                        color=tmp_df["color"],
                        showscale=False,
                        symbol=tmp_df["marker"].iloc[0] if colorblind_mode else "circle",
                    ),
                    marker_size=tmp_df["scatter_size"],
                    name=tmp_df["software_name"].iloc[0],
                )
            )

        fig.update_layout(
            width=700,
            height=700,
            xaxis=dict(title=f"Peptide {level.capitalize()}", linecolor="black"),
            yaxis=dict(title=f"Amino Acid {level.capitalize()}", linecolor="black"),
            clickmode="event+select",
        )
        fig.update_xaxes(showgrid=True, gridcolor="lightgray", gridwidth=1)
        fig.update_yaxes(showgrid=True, gridcolor="lightgray", gridwidth=1)

        # Watermark in the centre of the plotting area.
        fig.add_annotation(
            x=0.5,
            y=0.5,
            xref="paper",
            yref="paper",
            text="-Alpha-",
            font=dict(size=50, color="rgba(0,0,0,0.1)"),
            showarrow=False,
        )
        return fig

    @staticmethod
    def _padded_range(values: pd.Series) -> list:
        """Return ``[min - 5% of min, max + 5% of max]`` for an axis range."""
        lowest = values.min()
        highest = values.max()
        return [lowest - lowest * 0.05, highest + highest * 0.05]

    @staticmethod
    def _shorten_comment(comment: str, max_chars: int = 10) -> str:
        """Cut ``comment`` to ``max_chars`` characters, adding "..." only when it was cut."""
        if len(comment) > max_chars:
            return comment[:max_chars] + "..."
        return comment

    @classmethod
    def _build_hover_text(cls, row: pd.Series) -> str:
        """
        Build the HTML hover text for one datapoint row.

        Temporary (private) datapoints only show id, tool and the private
        comment; all other datapoints additionally list the run parameters
        and show the public submission comment.
        """
        text = (
            f"ProteoBench ID: {row['id']}<br>"
            f"Software tool: {row['software_name']} {row['software_version']}<br>"
        )
        # Explicit ``== True`` on purpose: a missing flag (NaN) is truthy but
        # must be handled like a non-temporary datapoint.
        if row["is_temporary"] == True:  # noqa: E712
            comment_column = "comments"
            comment_label = "Comment (private submission)"
        else:
            # TODO: Determine parameters based on module
            text += (
                f"Model checkpoint: {row['checkpoint']}<br>"
                f"Number of Beams: {row['n_beams']}<br>"
                f"Decoding Strategy: {row['decoding_strategy']}<br>"
                f"Precursor Tolerance: {row['precursor_mass_tolerance']}<br>"
                f"Tolerance for precursor removal: {row['remove_precursor_tol']}<br>"
                f"Number of peaks: {row['n_peaks']} <br>"
                f"Min mz: {row['min_mz']}<br>"
                f"Max mz: {row['max_mz']}<br>"
                f"Min peptide length: {row['min_peptide_length']}<br>"
                f"Max peptide length: {row['max_peptide_length']}<br>"
                f"Min intensity: {row['min_intensity']}<br>"
                f"Max intensity: {row['max_intensity']}<br>"
                f"Max precursor charge: {row['max_precursor_charge']}<br>"
                f"Isotope error range: {row['isotope_error_range']}<br>"
            )
            comment_column = "submission_comments"
            comment_label = "Comment (public submission)"

        if comment_column in row.index:
            comment = row[comment_column]
            # Missing comments arrive as NaN/None; only real text is shown.
            if isinstance(comment, str):
                text += f"{comment_label}: {cls._shorten_comment(comment)}"
        return text

    def generate_in_depth_plots(
        self, performance_data: pd.DataFrame, parse_settings: Any = None, **kwargs
    ) -> Dict[str, go.Figure]:
        """
        Generate module-specific in-depth plots.

        Parameters
        ----------
        performance_data : pd.DataFrame
            The performance data to plot.
        parse_settings : Any, optional
            Parse settings for the module (not used by de novo, included for
            signature compatibility).
        **kwargs : dict
            Additional parameters:
            - mod_labels: List[str] - list of PTM modification labels
            - mod_label: str - specific PTM label for detailed plot
            - feature: str - spectrum feature to plot
            - evaluation_type: str - evaluation type ("mass" or "exact")
            - software_colors: Dict[str, str] - color mapping for software tools

        Returns
        -------
        Dict[str, go.Figure]
            Dictionary mapping plot names to plotly figures.
        """
        mod_labels = kwargs.get(
            "mod_labels",
            [
                "M-Oxidation",
                "Q-Deamidation",
                "N-Deamidation",
                "N-term Acetylation",
                "N-term Carbamylation",
                "N-term Ammonia-loss",
            ],
        )
        mod_label = kwargs.get("mod_label", "M-Oxidation")
        feature = kwargs.get("feature", "Missing Fragmentation Sites")
        evaluation_type = kwargs.get("evaluation_type", "mass")
        software_colors = kwargs.get("software_colors", self._MAIN_SOFTWARE_COLORS)

        return {
            # PTM plots
            "ptm_overview": self.plot_ptm_overview(
                performance_data, mod_labels=mod_labels, software_colors=software_colors
            ),
            "ptm_specific": self.plot_ptm_specific(
                performance_data, mod_label=mod_label, software_colors=software_colors
            ),
            # Spectrum feature plot
            "spectrum_feature": self.plot_spectrum_feature(
                performance_data,
                feature=feature,
                evaluation_type=evaluation_type,
                software_colors=software_colors,
            ),
            # Species plot
            "species_overview": self.plot_species_overview(
                performance_data, evaluation_type=evaluation_type, software_colors=software_colors
            ),
        }

    def get_in_depth_plot_layout(self) -> list:
        """
        Define the layout configuration for displaying plots.

        Returns
        -------
        list
            List of plot configurations for organizing the UI display.
        """
        return [
            {"plots": ["ptm_overview", "ptm_specific"], "columns": 2, "title": "PTM Analysis"},
            {"plots": ["spectrum_feature"], "columns": 1, "title": "Spectrum Features"},
            {"plots": ["species_overview"], "columns": 1, "title": "Species Analysis"},
        ]

    def get_in_depth_plot_descriptions(self) -> Dict[str, str]:
        """
        Get descriptions for each in-depth plot.

        Returns
        -------
        Dict[str, str]
            Dictionary mapping plot names to their descriptions.
        """
        return {
            "ptm_overview": "Overview of precision across different post-translational modifications (PTMs). "
            "Shows how well each tool identifies modified amino acids.",
            "ptm_specific": "Detailed analysis of a specific PTM, comparing precision between ground truth "
            "and de novo predictions.",
            "spectrum_feature": "Analysis of precision relative to spectrum features such as missing "
            "fragmentation sites, peptide length, or explained intensity.",
            "species_overview": "Breakdown of precision across different species in the dataset.",
        }

    def plot_ptm_overview(
        self,
        benchmark_metrics_df: pd.DataFrame,
        mod_labels: List[str],
        software_colors: Optional[Dict[str, str]] = None,
    ) -> go.Figure:
        """
        Plot one precision line per datapoint across the given modifications.

        Parameters
        ----------
        benchmark_metrics_df : pd.DataFrame
            Datapoints; each row needs ``software_name`` and
            ``results["in_depth"]["PTM"]``.
        mod_labels : List[str]
            Modification labels shown on the x-axis, in this order.
        software_colors : Dict[str, str], optional
            Color mapping for software tools. Defaults to the in-depth
            palette; tools missing from the mapping are drawn in gray.

        Returns
        -------
        go.Figure
            The PTM overview figure.
        """
        palette = self._IN_DEPTH_SOFTWARE_COLORS if software_colors is None else software_colors

        fig = go.Figure()
        for _, row in benchmark_metrics_df.iterrows():
            x, y = self.get_modification_scores(row["results"]["in_depth"]["PTM"], mod_labels=mod_labels)
            tool = row["software_name"]
            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=y,
                    mode="lines+markers",
                    name=tool,
                    marker=dict(color=palette.get(tool, self._FALLBACK_COLOR)),
                )
            )

        fig.update_layout(
            width=700,
            height=400,
            xaxis=dict(title="Modification", color="black", gridwidth=2, linecolor="black"),
            yaxis=dict(linecolor="black"),
        )
        fig.update_yaxes(title="Precision", color="black", showgrid=True, gridcolor="lightgray", gridwidth=1)
        return fig

    def plot_ptm_specific(
        self,
        benchmark_metrics_df,
        mod_label,
        software_colors: Optional[Dict[str, str]] = None,
    ) -> go.Figure:
        """
        Plot, for one modification, the ground-truth ratio against the de novo ratio.

        For every datapoint x is ``correct_gt / counts_gt`` and y is
        ``correct_dn / counts_dn`` of ``results["in_depth"]["PTM"][mod_label]``.

        Parameters
        ----------
        benchmark_metrics_df : pd.DataFrame
            Datapoints; each row needs ``software_name`` and the PTM results.
        mod_label : str
            The modification to plot.
        software_colors : Dict[str, str], optional
            Color mapping for software tools. Defaults to the in-depth
            palette; tools missing from the mapping are drawn in gray.

        Returns
        -------
        go.Figure
            The PTM-specific figure with one single-point trace per datapoint.
        """
        palette = self._IN_DEPTH_SOFTWARE_COLORS if software_colors is None else software_colors

        fig = go.Figure()
        for _, row in benchmark_metrics_df.iterrows():
            counts = row["results"]["in_depth"]["PTM"][mod_label]
            # EPSILON in the denominator makes division by 0 impossible.
            x = counts["correct_gt"] / (counts["counts_gt"] + EPSILON)
            y = counts["correct_dn"] / (counts["counts_dn"] + EPSILON)
            tool = row["software_name"]
            fig.add_trace(
                go.Scatter(x=[x], y=[y], name=tool, marker=dict(color=palette.get(tool, self._FALLBACK_COLOR)))
            )

        fig.update_layout(
            width=500,
            height=500,
            xaxis=dict(title="Precision (Ground-truth)", color="black", gridwidth=2),
            yaxis=dict(title="Precision (denovo)", color="black", gridwidth=2),
        )
        return fig

    @staticmethod
    def get_modification_scores(mod_dict, mod_labels):
        """
        Compute the ground-truth ratio for each requested modification.

        Parameters
        ----------
        mod_dict : dict
            Mapping of modification label to a dict with ``correct_gt`` and
            ``counts_gt``.
        mod_labels : List[str]
            Labels to evaluate, in output order.

        Returns
        -------
        tuple
            ``(labels, scores)`` where ``scores[i]`` is
            ``correct_gt / (counts_gt + EPSILON)`` for ``labels[i]``.
        """
        labels = list(mod_labels)
        # EPSILON belongs in the denominator (as in plot_ptm_specific) so that a
        # modification with zero ground-truth counts scores 0 instead of raising.
        scores = [
            mod_dict[label]["correct_gt"] / (mod_dict[label]["counts_gt"] + EPSILON) for label in labels
        ]
        return labels, scores

    def plot_spectrum_feature(
        self,
        benchmark_metrics_df,
        feature,
        evaluation_type="mass",
        software_colors=None,
    ) -> go.Figure:
        """
        Plot precision against the bins of a spectrum feature.

        The upper panel shows one line per datapoint, the lower panel the
        median number of spectra per bin.

        Parameters
        ----------
        benchmark_metrics_df : pd.DataFrame
            Datapoints; each row needs ``id``, ``software_name`` and
            ``results["in_depth"]["Spectrum"][feature]``.
        feature : str
            Name of the spectrum feature to plot.
        evaluation_type : str
            Metric key read from each bin ("mass" or "exact").
        software_colors : Dict[str, str], optional
            Color mapping for software tools. Defaults to the in-depth
            palette (plus "test"); unknown tools are drawn in gray.

        Returns
        -------
        go.Figure
            Two-row figure sharing the x-axis.
        """
        palette = self._IN_DEPTH_SOFTWARE_COLORS_WITH_TEST if software_colors is None else software_colors

        def extract_bins(results):
            # Bin keys may be numeric; stringify them so they become plain
            # categorical column names.
            return {str(key): value for key, value in results["in_depth"]["Spectrum"][feature].items()}

        return self._plot_precision_by_category(
            benchmark_metrics_df,
            category=f"{feature}",
            extract_bins=extract_bins,
            evaluation_type=evaluation_type,
            software_colors=palette,
        )

    def plot_species_overview(
        self,
        benchmark_metrics_df,
        evaluation_type="mass",
        software_colors=None,
    ) -> go.Figure:
        """
        Plot precision per species.

        The upper panel shows one line per datapoint, the lower panel the
        median number of spectra per species.

        Parameters
        ----------
        benchmark_metrics_df : pd.DataFrame
            Datapoints; each row needs ``id``, ``software_name`` and
            ``results["in_depth"]["Species"]``.
        evaluation_type : str
            Metric key read for each species ("mass" or "exact").
        software_colors : Dict[str, str], optional
            Color mapping for software tools. Defaults to the in-depth
            palette (plus "test"); unknown tools are drawn in gray.

        Returns
        -------
        go.Figure
            Two-row figure sharing the x-axis.
        """
        palette = self._IN_DEPTH_SOFTWARE_COLORS_WITH_TEST if software_colors is None else software_colors

        return self._plot_precision_by_category(
            benchmark_metrics_df,
            category="Species",
            extract_bins=lambda results: results["in_depth"]["Species"],
            evaluation_type=evaluation_type,
            software_colors=palette,
        )

    @staticmethod
    def _coerce_metrics(value):
        """
        Normalise one melted cell to a metrics dict.

        Missing cells (a datapoint without that bin) become an all-zero
        entry. String cells are parsed with ``ast.literal_eval``.
        """
        if isinstance(value, dict):
            return value
        if isinstance(value, str):
            # NOTE(review): the previous implementation used eval() here.
            # literal_eval only accepts Python literals, so stringified dicts
            # still parse but arbitrary expressions are rejected.
            return ast.literal_eval(value)
        if value is None or (isinstance(value, float) and np.isnan(value)):
            return {"exact": 0.0, "mass": 0.0, "n_spectra": 0}
        return value

    @staticmethod
    def _style_two_panel_figure(fig: go.Figure, x_title: str) -> go.Figure:
        """Apply the shared size, axis titles and reversed count axis of the two-row plots."""
        fig.update_layout(
            height=600,
            width=600,
            xaxis=dict(title=None, color="black"),
            yaxis=dict(title="Precision", color="black"),
            xaxis2=dict(title=x_title, color="black"),
            yaxis2=dict(title="Number of Spectra", color="black"),
            margin=dict(t=50),
        )
        # Bars in the lower panel grow downwards from the shared x-axis.
        fig.update_yaxes(autorange="reversed", row=2, col=1)
        return fig

    def _plot_precision_by_category(
        self, benchmark_metrics_df, category, extract_bins, evaluation_type, software_colors
    ) -> go.Figure:
        """
        Shared implementation of the spectrum-feature and species plots.

        ``extract_bins`` maps one ``results`` dict to a mapping
        ``bin -> {"exact": ..., "mass": ..., "n_spectra": ...}``.
        """
        fig = make_subplots(
            rows=2,
            cols=1,
            shared_xaxes=True,
            row_heights=[0.8, 0.2],
            vertical_spacing=0,
            subplot_titles=(f"{category} vs Precision", None),
        )

        if len(benchmark_metrics_df) == 0:
            # Nothing to plot: return the styled, empty figure.
            fig.add_trace(go.Scatter())
            fig.add_trace(go.Bar())
            return self._style_two_panel_figure(fig, category)

        # Reformat to long form: one row per (datapoint, bin).
        benchmark_metrics_df = benchmark_metrics_df.reset_index(drop=True)
        wide = pd.DataFrame([extract_bins(results) for results in benchmark_metrics_df["results"].tolist()])
        long_df = (
            pd.concat([wide, benchmark_metrics_df[["software_name", "id"]]], axis=1)
            .melt(id_vars=["id", "software_name"])
            .rename(columns={"variable": category, "value": "metrics"})
        )
        long_df["metrics"] = long_df["metrics"].apply(self._coerce_metrics)

        # Upper panel: one line per datapoint.
        for dtp_id in long_df["id"].unique():
            dtp_rows = long_df.loc[long_df["id"] == dtp_id]
            tool = dtp_rows["software_name"].iloc[0]
            fig.add_trace(
                go.Scatter(
                    x=dtp_rows[category].tolist(),
                    y=[metrics[evaluation_type] for metrics in dtp_rows["metrics"]],
                    name=tool,
                    marker=dict(color=software_colors.get(tool, self._FALLBACK_COLOR)),
                    mode="lines+markers",
                ),
                row=1,
                col=1,
            )

        # Lower panel: median spectrum count per bin over all datapoints, with
        # the per-datapoint counts listed in the hover text.
        bar_xaxis, bar_counts, hovertexts = [], [], []
        for bin_label, group in long_df.groupby(category):
            bar_xaxis.append(bin_label)
            bar_counts.append(float(np.median([metrics["n_spectra"] for metrics in group["metrics"]])))
            hovertexts.append(
                "Number of spectra for each tool"
                + "".join(
                    f"<br>{dtp_id}: {metrics['n_spectra']}"
                    for dtp_id, metrics in zip(group["id"], group["metrics"])
                )
            )
        fig.add_trace(
            go.Bar(x=bar_xaxis, y=bar_counts, hovertext=hovertexts, marker=dict(color="gray"), showlegend=False),
            row=2,
            col=1,
        )

        return self._style_two_panel_figure(fig, category)

    @staticmethod
    def plot_species_specific():
        """Placeholder for a species-specific in-depth plot; not implemented yet."""
        # NOTE(review): the stub takes no ``self``; it is declared static so it
        # can be called on the class and on instances alike. Revisit the
        # signature once the plot is implemented.