Coverage for src / sparkle / platform / latex.py: 98%
42 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 15:31 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 15:31 +0000
"""Helper classes and methods for LaTeX and BibTeX, including a Plotly comparison plot."""
3import math
5import numpy as np
6import pandas as pd
7import plotly
8import plotly.express as px
9import pylatex as pl
10import kaleido
# Ensure Kaleido has a Chrome binary available for static image export.
# This runs once, at import time of this module.
# NOTE(review): get_chrome_sync presumably downloads Chrome when none is found,
# which would make importing this module require network access -- confirm
# against the kaleido documentation.
kaleido.get_chrome_sync()
class AutoRef(pl.base_classes.CommandBase):
    r"""PyLaTeX command class emitting ``\autoref``.

    Declares a dependency on the ``hyperref`` LaTeX package, which is the
    package that defines ``\autoref``.
    """

    packages = [pl.Package("hyperref")]
    _latex_name = "autoref"
def _prefers_log_scale(x_values: np.ndarray, y_values: np.ndarray) -> bool:
    """Decide whether the comparison should be drawn on log-log axes.

    A linear regression of ``y_values`` on ``x_values`` is fitted. Log scale is
    chosen unless the two series are clearly positively correlated
    (r-value > 0.65 with p-value < 0.05).

    Args:
        x_values: NaN-free values for the x axis.
        y_values: NaN-free values for the y axis, same length as ``x_values``.

    Returns:
        True if log scale should be used. Always False for degenerate data
        (fewer than two points, or a series without variance).
    """
    from scipy import stats

    # Guard against degenerate data (single point or no variance) to avoid
    # SciPy warnings; ``and`` short-circuits so ``ptp`` never sees empty input.
    enough_points = x_values.size >= 2 and y_values.size >= 2
    if not (enough_points and np.ptp(x_values) > 0 and np.ptp(y_values) > 0):
        return False
    regression = stats.linregress(x_values, y_values)
    # A strong, significant linear correlation stays on linear axes;
    # otherwise log axes are used.
    return not (regression.rvalue > 0.65 and regression.pvalue < 0.05)


def _linear_tick_step(max_value: float) -> float:
    """Return ``10 ** (ceil(log10(max_value)) - 1)``, e.g. 100 for 523, 0.1 for 0.7.

    ``math.log10`` is used instead of ``math.log(x, 10)`` because it is more
    accurate, which matters since the result is rounded with ``ceil``.
    Non-positive maxima have no logarithm and fall back to a step of 1.
    """
    if max_value <= 0:
        return 1
    return 10 ** (math.ceil(math.log10(max_value)) - 1)


def _add_reference_lines(
    fig: plotly.graph_objects.Figure, max_value: float
) -> None:
    """Add the y = x diagonal and the horizontal/vertical lines at ``max_value``."""
    # Dotted diagonal: points above it have a larger y than x, below it smaller.
    fig.add_shape(
        type="line",
        x0=0,
        y0=0,
        x1=max_value,
        y1=max_value,
        line=dict(color="grey", dash="dot", width=1),
    )
    # Maximum lines. NOTE(review): horizontal width 1.5 vs vertical width 0.5
    # looks asymmetric -- kept as-is, confirm whether intended.
    fig.add_shape(
        type="line",
        opacity=0.7,
        x0=0,
        y0=max_value,
        x1=max_value,
        y1=max_value,
        line=dict(color="red", width=1.5, dash="longdash"),
    )
    fig.add_shape(
        type="line",
        opacity=0.7,
        x0=max_value,
        y0=0,
        x1=max_value,
        y1=max_value,
        line=dict(color="red", width=0.5, dash="longdash"),
    )


def comparison_plot(
    data_frame: pd.DataFrame, title: str = None
) -> plotly.graph_objects.Figure:
    """Create a square scatter plot comparing two columns of a data frame.

    The first column is used for the x axis, the second for the y axis. Both
    axes share the same range. Log-log axes are used unless the two columns are
    clearly linearly correlated. A dotted y = x diagonal and dashed red lines
    at the maximum value are drawn for orientation.

    Args:
        data_frame: The data frame with the data. It is not modified.
        title: The title of the plot, or None for no title.

    Returns:
        The plot object.

    Raises:
        ValueError: If ``data_frame`` contains no non-NaN value.
    """
    x_values = data_frame[data_frame.columns[0]].to_numpy(dtype=float)
    y_values = data_frame[data_frame.columns[1]].to_numpy(dtype=float)
    # Rows where either coordinate is missing cannot inform the regression.
    valid_mask = ~(np.isnan(x_values) | np.isnan(y_values))
    log_scale = _prefers_log_scale(x_values[valid_mask], y_values[valid_mask])

    if log_scale:
        # Log axes cannot show zero or negative values: replace them with the
        # smallest positive float. ``mask`` returns a new frame, so the
        # caller's data frame is left untouched (NaN compares False and stays).
        non_positive = data_frame <= 0
        if non_positive.any(axis=None):
            data_frame = data_frame.mask(non_positive, np.nextafter(0, 1))

    # Maximum value should come from the objective?
    # NOTE(review): min/max reduce over every column, not only the two plotted.
    min_value, max_value = data_frame.min(axis=None), data_frame.max(axis=None)
    if pd.isna(max_value):
        raise ValueError("comparison_plot requires at least one non-NaN value.")

    # Extend slightly beyond min/max for interpretability.
    if log_scale:
        # Round the maximum up to the next power of ten (a full decade).
        # NOTE(review): unlike the linear branch, this also moves the reference
        # lines to the rounded value -- confirm that is intended.
        max_value = 10 ** math.ceil(math.log10(max_value))
        plot_range = (min_value, max_value)
        # On log axes plotly reads dtick=1 as one tick per decade.
        dtick = 1
    else:
        # Tick every 10^(ceil(log10(max))) / 10; round the upper bound up to
        # the next such tick.
        dtick = _linear_tick_step(max_value)
        next_step_max = math.ceil(max_value / dtick) * dtick
        plot_range = (
            (0, next_step_max) if min_value > 0 else (min_value, next_step_max)
        )

    fig = px.scatter(
        data_frame=data_frame,
        x=data_frame.columns[0],
        y=data_frame.columns[1],
        range_x=plot_range,
        range_y=plot_range,
        title=title,
        log_x=log_scale,
        log_y=log_scale,
        width=1000,
        height=1000,
    )
    _add_reference_lines(fig, max_value)
    fig.update_traces(marker=dict(color="RoyalBlue", symbol="x"))
    fig.update_layout(plot_bgcolor="white", autosize=False, width=1000, height=1000)

    minor = dict(ticks="inside", ticklen=6, showgrid=True) if log_scale else None
    # Both axes share one style so the plot stays symmetric.
    axis_style = dict(
        mirror=True,
        tickmode="linear",
        ticks="outside",
        tick0=0,
        minor=minor,
        dtick=dtick,
        showline=True,
        linecolor="black",
        gridcolor="lightgrey",
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    return fig