Coverage for sparkle/solver/selector.py: 34%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-27 09:10 +0000

1"""File to handle a Selector for selecting Solvers.""" 

2from __future__ import annotations 

3from pathlib import Path 

4import subprocess 

5import ast 

6 

7import runrunner as rrr 

8from runrunner import Runner, Run 

9 

10from sparkle.types import SparkleCallable, SparkleObjective 

11from sparkle.structures import FeatureDataFrame, PerformanceDataFrame 

12 

13 

class Selector(SparkleCallable):
    """The Selector class for handling Algorithm Selection.

    Wraps a Selector builder script that is invoked through ``python3``.
    The script is used both to construct a selector model from performance
    and feature data (``construct``) and to load a constructed model and
    predict an algorithm schedule for a feature vector (``run``).
    """

    def __init__(self: Selector,
                 executable_path: Path,
                 raw_output_directory: Path) -> None:
        """Initialize the Selector object.

        Args:
            executable_path: Path of the Selector executable.
            raw_output_directory: Directory where the Selector will write its
                raw output. It is created (including parent directories) if it
                does not exist yet.
        """
        self.selector_builder_path = executable_path
        self.directory = self.selector_builder_path.parent
        self.name = self.selector_builder_path.name
        self.raw_output_directory = raw_output_directory
        # exist_ok avoids a check-then-create race when several processes
        # initialise a Selector with the same output directory.
        self.raw_output_directory.mkdir(parents=True, exist_ok=True)

    def build_construction_cmd(
            self: Selector,
            target_file: Path,
            performance_data: Path,
            feature_data: Path,
            objective: SparkleObjective,
            runtime_cutoff: int | float | str | None = None,
            wallclock_limit: int | float | str | None = None) -> list[str | Path]:
        """Builds the commandline call string for constructing the Selector.

        Args:
            target_file: Path to the file to save the Selector to.
            performance_data: Path to the performance data csv.
            feature_data: Path to the feature data csv.
            objective: The objective to optimize for selection. A time-based
                objective maps to ``runtime``, anything else to
                ``solution_quality``.
            runtime_cutoff: Cutoff for the runtime in seconds. When given,
                ``--tune`` is also passed. Defaults to None.
            wallclock_limit: Cutoff for total wallclock in seconds.
                Defaults to None.

        Returns:
            The command list for constructing the Selector.
        """
        objective_function = "runtime" if objective.time else "solution_quality"
        # Invoke through python3 so the builder does not need execution rights
        cmd = ["python3", self.selector_builder_path,
               "--performance_csv", performance_data,
               "--feature_csv", feature_data,
               "--objective", objective_function,
               "--save", target_file]
        if runtime_cutoff is not None:
            cmd.extend(["--runtime_cutoff", str(runtime_cutoff), "--tune"])
        if wallclock_limit is not None:
            cmd.extend(["--wallclock_limit", str(wallclock_limit)])
        return cmd

    def construct(self: Selector,
                  target_file: Path | str,
                  performance_data: PerformanceDataFrame,
                  feature_data: FeatureDataFrame,
                  objective: SparkleObjective,
                  runtime_cutoff: int | float | str | None = None,
                  wallclock_limit: int | float | str | None = None,
                  run_on: Runner = Runner.SLURM,
                  sbatch_options: list[str] | None = None,
                  base_dir: Path = Path()) -> Run:
        """Construct the Selector.

        The dataframes are first exported to the Selector's CSV format next to
        ``target_file``. The construction command is then queued on the
        requested runner.

        Args:
            target_file: Path to the file to save the Selector to. A plain
                string is interpreted relative to ``raw_output_directory``.
            performance_data: Performance data to construct the Selector from.
            feature_data: Feature data to construct the Selector from.
            objective: The objective to optimize for selection.
            runtime_cutoff: Cutoff for the runtime in seconds.
            wallclock_limit: Cutoff for the wallclock time in seconds.
            run_on: Which runner to use. Defaults to slurm.
            sbatch_options: Additional options to pass to sbatch.
            base_dir: The base directory to run the Selector in.

        Returns:
            The runrunner Run for the construction job. With a local runner
            the job has already finished when this returns.
        """
        if isinstance(target_file, str):
            target_file = self.raw_output_directory / target_file
        # Convert the dataframes to Selector Format
        performance_csv = performance_data.to_autofolio(objective=objective,
                                                        target=target_file.parent)
        feature_csv = feature_data.to_autofolio(target_file.parent)
        cmd = self.build_construction_cmd(target_file,
                                          performance_csv,
                                          feature_csv,
                                          objective,
                                          runtime_cutoff,
                                          wallclock_limit)

        # The runner receives the command as one joined string.
        # NOTE(review): arguments are not shell-quoted, so paths containing
        # spaces would be split — confirm how runrunner executes cmd.
        cmd_str = " ".join([str(c) for c in cmd])
        construct = rrr.add_to_queue(
            runner=run_on,
            cmd=[cmd_str],
            name="construct_selector",
            base_dir=base_dir,
            stdout=Path("normal.log"),
            stderr=Path("error.log"),
            sbatch_options=sbatch_options)
        if run_on == Runner.LOCAL:
            # Only a local run can be checked synchronously for its output file
            construct.wait()
            if not target_file.is_file():
                print(f"Selector construction of {self.name} failed!")

        return construct

    def build_cmd(self: Selector,
                  selector_path: Path,
                  feature_vector: list | str) -> list[str | Path]:
        """Builds the commandline call string for running the Selector.

        Args:
            selector_path: Path to a constructed Selector model to load.
            feature_vector: Feature values. A list is joined with single
                spaces into one argument; a string is passed through as-is.

        Returns:
            The command list for running the Selector.
        """
        if isinstance(feature_vector, list):
            feature_vector = " ".join(map(str, feature_vector))

        return ["python3", self.selector_builder_path,
                "--load", selector_path,
                "--feature_vec", feature_vector]

    def run(self: Selector,
            selector_path: Path,
            feature_vector: list | str) -> list | None:
        """Run the Selector, returning the prediction schedule upon success.

        Args:
            selector_path: Path to a constructed Selector model to load.
            feature_vector: Feature values, as a list or a pre-joined string.

        Returns:
            The predicted schedule parsed from the Selector's stdout, or None
            if the process exited non-zero or no schedule could be parsed.
        """
        cmd = self.build_cmd(selector_path, feature_vector)
        run = subprocess.run(cmd, capture_output=True)
        if run.returncode != 0:
            print(f"Selector run of {self.name} failed! Error:\n"
                  f"{run.stderr.decode()}")
            return None
        stdout = run.stdout.decode()
        # Process the prediction schedule from the output
        schedule = Selector.process_predict_schedule_output(stdout)
        if schedule is None:
            # The schedule is parsed from stdout, so that is the stream to show
            # when parsing fails.
            print(f"Error getting predict schedule! Selector {self.name} output:\n"
                  f"{stdout}")
        return schedule

    @staticmethod
    def process_predict_schedule_output(output: str) -> list | None:
        """Return the predicted algorithm schedule as a list.

        Finds the first line of ``output`` that, after stripping surrounding
        whitespace, starts with ``"Selected Schedule [(algorithm, budget)]: "``.
        The text after that prefix is evaluated as a Python literal.

        Args:
            output: Decoded stdout of a Selector run.

        Returns:
            The parsed schedule, or None if no schedule line is present or the
            text after the prefix is not a valid Python literal.
        """
        prefix_string = "Selected Schedule [(algorithm, budget)]: "
        for line in output.splitlines():
            stripped = line.strip()
            if not stripped.startswith(prefix_string):
                continue
            try:
                # literal_eval only evaluates literals, so external Selector
                # output cannot execute code here.
                return ast.literal_eval(stripped[len(prefix_string):])
            except (ValueError, TypeError, SyntaxError):
                # Malformed schedule: report it the same way as a missing one
                return None
        return None