Coverage for sparkle/platform/latex.py: 34%
35 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-29 10:17 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-29 10:17 +0000
"""Helper classes/methods for LaTeX and BibTeX."""
3import math
5import numpy as np
6import pandas as pd
7import plotly
8import plotly.express as px
9import pylatex as pl
10import kaleido
# Runs once, as a side effect of importing this module: ensure Chrome is
# available for Kaleido before any plot is produced.
# NOTE(review): presumably this fetches Chrome when none is found -- confirm
# against the Kaleido documentation before relying on it offline.
kaleido.get_chrome_sync()
class AutoRef(pl.base_classes.CommandBase):
    """PyLaTeX command object for the ``autoref`` macro."""

    # LaTeX package declared as a requirement of this command.
    packages = [pl.Package("hyperref")]
    # Name of the LaTeX command this class stands for.
    _latex_name = "autoref"
def comparison_plot(
    data_frame: pd.DataFrame, title: str = None
) -> plotly.graph_objects.Figure:
    """Creates a comparison plot from the given data frame.

    The first column is used for the x axis, the second for the y axis. Both
    axes share one range. Linear axes are used when a linear regression of the
    two columns shows a clear positive correlation (r > 0.65 with p < 0.05);
    otherwise both axes are log scaled. A dotted diagonal and dashed lines at
    the maximum value are drawn as visual guides.

    Args:
        data_frame: The data frame with the data. It is not modified. Only the
            first two columns are plotted, but the axis range is derived from
            every value in the frame.
        title: The title of the plot

    Returns:
        The plot object

    Raises:
        ValueError: Raised by ``math.log10`` when linear scaling is selected
            and the largest value in the frame is not positive.
    """
    from scipy import stats

    x_column, y_column = data_frame.columns[0], data_frame.columns[1]

    # Determine if data is log scale, linregress tells us how linear the data is
    regression = stats.linregress(
        data_frame[x_column].to_numpy(),
        data_frame[y_column].to_numpy(),
    )
    log_scale = not (regression.rvalue > 0.65 and regression.pvalue < 0.05)

    if log_scale and (data_frame <= 0).any(axis=None):
        # Log scale cannot deal with negative and zero values, set them to the
        # smallest positive float. DataFrame.mask returns a new frame, so the
        # caller's data is left untouched.
        data_frame = data_frame.mask(data_frame <= 0, np.nextafter(0, 1))

    # Maximum value should come from the objective?
    min_value, max_value = data_frame.min(axis=None), data_frame.max(axis=None)

    # Slightly more than min/max for interpretability
    if log_scale:
        # Take next step on log scale: round the maximum up to a power of ten.
        # math.log10 is used because it is more accurate than math.log(x, 10).
        max_value = 10 ** math.ceil(math.log10(max_value))
        plot_range = (min_value, max_value)
        # On a log axis the tick spacing is resolved by plotly.
        dtick = 1
    else:
        # Take previous/next step on linear scale: one tick every
        # 10^(order of magnitude of max - 1), upper bound rounded up to a tick.
        tick_step = 10 ** (math.ceil(math.log10(max_value)) - 1)
        upper_bound = math.ceil(max_value / tick_step) * tick_step
        lower_bound = 0 if min_value > 0 else min_value
        plot_range = (lower_bound, upper_bound)
        dtick = tick_step

    fig = px.scatter(
        data_frame=data_frame,
        x=x_column,
        y=y_column,
        range_x=plot_range,
        range_y=plot_range,
        title=title,
        log_x=log_scale,
        log_y=log_scale,
        width=1000,
        height=1000,
    )

    # NOTE(review): the guide lines below start at 0, which has no position on
    # a log axis -- confirm they render as intended when log_scale is True.
    # Add dividing diagonal
    fig.add_shape(
        type="line",
        x0=0,
        y0=0,
        x1=max_value,
        y1=max_value,
        line=dict(color="grey", dash="dot", width=1),
    )
    # Add maximum lines (horizontal, then vertical)
    # NOTE(review): the two lines use different widths (1.5 vs 0.5); kept as in
    # the original -- confirm whether that is intentional.
    fig.add_shape(
        type="line",
        opacity=0.7,
        x0=0,
        y0=max_value,
        x1=max_value,
        y1=max_value,
        line=dict(color="red", width=1.5, dash="longdash"),
    )
    fig.add_shape(
        type="line",
        opacity=0.7,
        x0=max_value,
        y0=0,
        x1=max_value,
        y1=max_value,
        line=dict(color="red", width=0.5, dash="longdash"),
    )

    fig.update_traces(marker=dict(color="RoyalBlue", symbol="x"))
    fig.update_layout(plot_bgcolor="white", autosize=False, width=1000, height=1000)

    # Both axes are styled identically; minor ticks are only shown on log scale.
    axis_style = dict(
        mirror=True,
        tickmode="linear",
        ticks="outside",
        tick0=0,
        minor=dict(ticks="inside", ticklen=6, showgrid=True) if log_scale else None,
        dtick=dtick,
        showline=True,
        linecolor="black",
        gridcolor="lightgrey",
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    return fig