Coverage for src / sparkle / platform / latex.py: 98%

42 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 15:31 +0000

1"""Helper classes/method for LaTeX and bibTeX.""" 

2 

3import math 

4 

5import numpy as np 

6import pandas as pd 

7import plotly 

8import plotly.express as px 

9import pylatex as pl 

10import kaleido 

11 

12kaleido.get_chrome_sync() # Ensure chrome is available for Kaleido 

13 

14 

class AutoRef(pl.base_classes.CommandBase):
    """PyLaTeX command object for the LaTeX ``autoref`` command.

    Sets the LaTeX command name to ``autoref`` and declares the ``hyperref``
    package as a dependency of the command.
    """

    # Package(s) declared as required by this command.
    packages = [pl.Package("hyperref")]

    # Name of the LaTeX command this class stands for.
    _latex_name = "autoref"

20 

21 

def comparison_plot(
    data_frame: pd.DataFrame, title: str = None
) -> plotly.graph_objects.Figure:
    """Create a square scatter plot comparing the first two columns of a data frame.

    The first column is used for the x axis, the second for the y axis. Both axes
    share the same range. A linear regression over the rows where both values are
    present decides the scale: linear axes are used when the correlation is
    strongly positive and significant (r > 0.65 and p < 0.05) or when the data is
    degenerate (fewer than two points, or no variance); logarithmic axes otherwise.

    On a logarithmic scale all non-positive values are replaced by the smallest
    positive float, since they have no logarithm. This is done on a copy; the
    data frame passed in by the caller is never modified.

    Args:
        data_frame: The data frame with the data, containing at least two columns
        title: The title of the plot

    Returns:
        The plot object
    """
    from scipy import stats

    x_values = data_frame[data_frame.columns[0]].to_numpy(dtype=float)
    y_values = data_frame[data_frame.columns[1]].to_numpy(dtype=float)

    # Only rows where both coordinates are present take part in the regression
    valid_mask = ~(np.isnan(x_values) | np.isnan(y_values))
    x_values = x_values[valid_mask]
    y_values = y_values[valid_mask]

    # Guard against degenerate data (single point or no variance) to avoid SciPy warnings.
    has_spread = (
        x_values.size >= 2
        and y_values.size >= 2
        and np.ptp(x_values) > 0
        and np.ptp(y_values) > 0
    )
    if has_spread:
        regression = stats.linregress(x_values, y_values)
        # Linear axes only for a strong, significant positive correlation
        log_scale = not (regression.rvalue > 0.65 and regression.pvalue < 0.05)
    else:
        log_scale = False

    if log_scale and (data_frame <= 0).any(axis=None):
        # Log scale cannot deal with negative and zero values, set to smallest non zero.
        # Zeros must be included: a zero lower bound has no logarithm for the axis range.
        # Work on a copy so the caller's data frame stays untouched.
        data_frame = data_frame.copy()
        data_frame[data_frame <= 0] = np.nextafter(0, 1)

    # Maximum value should come from the objective?
    min_value, max_value = data_frame.min(axis=None), data_frame.max(axis=None)

    # Slightly more than min/max for interpretability
    if log_scale:  # Take next step on log scale
        # log10 is exact at powers of ten, unlike log(x, 10)
        max_value = 10 ** math.ceil(math.log10(max_value))
        plot_range = (min_value, max_value)
        dtick = 1  # Log scale ticks are resolved by plotly
    else:  # Take previous/next step on linear scale
        # The logarithm needs a positive finite value. Without a positive maximum
        # (all-zero, all-negative or empty data) fall back to the magnitude of the
        # minimum, or to 1.0, instead of raising a math domain error.
        if math.isfinite(max_value) and max_value > 0:
            reference = max_value
        elif math.isfinite(min_value) and min_value < 0:
            reference = abs(min_value)
        else:
            reference = 1.0
        # Tick every 10^(ceil(log10(reference)) - 1)
        dtick = 10 ** (math.ceil(math.log10(reference)) - 1)
        next_step_max = math.ceil(reference / dtick) * dtick
        # Start at zero unless the data itself goes below it
        plot_range = (min_value, next_step_max) if min_value < 0 else (0, next_step_max)
        if not math.isfinite(max_value):
            # Keep the reference lines drawable when there is no data at all
            max_value = next_step_max

    fig = px.scatter(
        data_frame=data_frame,
        x=data_frame.columns[0],
        y=data_frame.columns[1],
        range_x=plot_range,
        range_y=plot_range,
        title=title,
        log_x=log_scale,
        log_y=log_scale,
        width=1000,
        height=1000,
    )
    # Add dividing diagonal
    fig.add_shape(
        type="line",
        x0=0,
        y0=0,
        x1=max_value,
        y1=max_value,
        line=dict(color="grey", dash="dot", width=1),
    )
    # Add maximum lines
    fig.add_shape(
        type="line",
        opacity=0.7,
        x0=0,
        y0=max_value,
        x1=max_value,
        y1=max_value,
        line=dict(color="red", width=1.5, dash="longdash"),
    )
    fig.add_shape(
        type="line",
        opacity=0.7,
        x0=max_value,
        y0=0,
        x1=max_value,
        y1=max_value,
        line=dict(color="red", width=0.5, dash="longdash"),
    )
    fig.update_traces(marker=dict(color="RoyalBlue", symbol="x"))
    fig.update_layout(plot_bgcolor="white", autosize=False, width=1000, height=1000)

    # Both axes are styled identically; minor ticks are only shown on a log scale
    minor = dict(ticks="inside", ticklen=6, showgrid=True) if log_scale else None
    axis_style = dict(
        mirror=True,
        tickmode="linear",
        ticks="outside",
        tick0=0,
        minor=minor,
        dtick=dtick,
        showline=True,
        linecolor="black",
        gridcolor="lightgrey",
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    return fig