Coverage for sparkle/solver/ablation.py: 82%

112 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-05 14:48 +0000

1#!/usr/bin/env python3 

2# -*- coding: UTF-8 -*- 

3"""Helper functions for ablation analysis.""" 

4from __future__ import annotations 

5import re 

6import shutil 

7import decimal 

8from pathlib import Path 

9 

10import runrunner as rrr 

11from runrunner.base import Runner, Run 

12 

13from sparkle.CLI.help import global_variables as gv 

14from sparkle.CLI.help import logging as sl 

15 

16from sparkle.configurator.implementations import SMAC2 

17from sparkle.platform import CommandName 

18from sparkle.solver import Solver 

19from sparkle.instance import InstanceSet 

20 

21 

class AblationScenario:
    """Class for ablation analysis."""

    # Seed passed to the ablation tool. The tool's log file names embed it
    # (ablation-run<seed>.txt / ablation-validation-run<seed>.txt), so every
    # place that refers to those files must use this same value.
    _SEED = 1234

    def __init__(self: AblationScenario,
                 solver: Solver,
                 train_set: InstanceSet,
                 test_set: InstanceSet | None,
                 output_dir: Path,
                 ablation_executable: Path | None = None,
                 ablation_validation_executable: Path | None = None,
                 override_dirs: bool = False) -> None:
        """Initialize ablation scenario and create its directory structure.

        Args:
            solver: Solver object
            train_set: The training instance set
            test_set: The test instance set. May be None, in which case no
                validation run is submitted and the training set is also used
                as the test instance file.
            output_dir: The output directory
            ablation_executable: (Only for execution) The ablation executable
            ablation_validation_executable: (Only for execution) The validation
                executable
            override_dirs: Whether to clean the scenario directory if it
                already exists
        """
        self.ablation_exec = ablation_executable
        self.ablation_validation_exec = ablation_validation_executable
        self.solver = solver
        self.train_set = train_set
        self.test_set = test_set
        self.output_dir = output_dir
        self.scenario_name = f"{self.solver.name}_{self.train_set.name}"
        if self.test_set is not None:
            self.scenario_name += f"_{self.test_set.name}"
        self.scenario_dir = self.output_dir / self.scenario_name
        if override_dirs and self.scenario_dir.exists():
            print("Warning: found existing ablation scenario. This will be removed.")
            shutil.rmtree(self.scenario_dir)

        # Create required scenario directories
        self.tmp_dir = self.scenario_dir / "tmp"
        self.tmp_dir.mkdir(parents=True, exist_ok=True)

        self.validation_dir = self.scenario_dir / "validation"
        self.validation_dir_tmp = self.validation_dir / "tmp"
        self.validation_dir_tmp.mkdir(parents=True, exist_ok=True)
        # Result table written by the validation run; see check_for_ablation
        self.table_file = (self.validation_dir / "log"
                           / f"ablation-validation-run{self._SEED}.txt")

    @staticmethod
    def _expand_scientific_notation(config_str: str) -> str:
        """Rewrite numbers written in scientific notation as fixed-point decimals.

        Ablation cannot deal with E scientific notation in floats. Only tokens
        that are numbers with an exponent (optionally signed, optionally
        quoted) are rewritten. Other values that merely contain an "e", such as
        categorical values, are left untouched. Each number is converted via
        float and rounded to 16 significant digits before being printed in
        fixed-point form.
        """
        ctx = decimal.Context(prec=16)

        def _to_fixed_point(match: re.Match) -> str:
            return format(ctx.create_decimal(float(match.group(0))), "f")

        # The lookarounds keep the match from starting or ending inside a
        # longer identifier or number (e.g. a parameter name like "lr1e3").
        return re.sub(r"(?<![\w.])[-+]?(?:\d+\.?\d*|\.\d+)[eE][-+]?\d+(?![\w.])",
                      _to_fixed_point, config_str)

    def create_configuration_file(self: AblationScenario) -> None:
        """Create the ablation configuration file for this scenario.

        The source configuration is the solver default. The target configuration
        is the optimal configuration reported by the configured configurator for
        the latest configuration scenario. Any PCS parameter it does not set is
        completed with that parameter's default value.

        The file is written to the scenario directory. A copy whose execution
        directory points to the validation tmp directory is written to the
        validation directory.
        """
        objective = gv.settings().get_general_sparkle_objectives()[0]
        configurator = gv.settings().get_general_sparkle_configurator()
        config_scenario = gv.latest_scenario().get_configuration_scenario(
            configurator.scenario_class)
        _, opt_config_str = configurator.get_optimal_configuration(
            config_scenario)

        # We need to check which params are missing and supplement with default
        # values. Compare against the set of "-name" flags instead of a
        # substring test, which would treat e.g. "alpha" as present when only
        # "EDRalpha" is set.
        present_params = set(re.findall(r"(?<!\S)-(\S+)", opt_config_str))
        missing_params = [f"-{p['name']} {p['default']}"
                          for p in self.solver.get_pcs()
                          if p["name"] not in present_params]
        # strip() avoids a leading space when the optimal configuration is empty
        opt_config_str = " ".join([opt_config_str, *missing_params]).strip()

        # Ablation cannot deal with E scientific notation in floats
        opt_config_str = self._expand_scientific_notation(opt_config_str)

        smac_run_obj = SMAC2.get_smac_run_obj(objective)
        objective_str = "MEAN10" if smac_run_obj == "RUNTIME" else "MEAN"
        run_cutoff_time = gv.settings().get_general_target_cutoff_time()
        run_cutoff_length = gv.settings().get_smac2_target_cutoff_length()
        concurrent_clis = gv.settings().get_slurm_max_parallel_runs_per_node()
        ablation_racing = gv.settings().get_ablation_racing_flag()
        pcs_file_path = f"{self.solver.get_pcs_file().absolute()}"  # Get Solver PCS

        # Create config file
        config_file = self.scenario_dir / "ablation_config.txt"
        config = (f'algo = "{SMAC2.configurator_target.absolute()} '
                  f'{self.solver.directory.absolute()} {objective}"\n'
                  f"execdir = {self.tmp_dir.absolute()}\n"
                  "experimentDir = ./\n"
                  f"deterministic = {1 if self.solver.deterministic else 0}\n"
                  f"run_obj = {smac_run_obj}\n"
                  f"overall_obj = {objective_str}\n"
                  f"cutoffTime = {run_cutoff_time}\n"
                  f"cutoff_length = {run_cutoff_length}\n"
                  f"cli-cores = {concurrent_clis}\n"
                  f"useRacing = {ablation_racing}\n"
                  f"seed = {self._SEED}\n"
                  f"paramfile = {pcs_file_path}\n"
                  "instance_file = instances_train.txt\n"
                  "test_instance_file = instances_test.txt\n"
                  "sourceConfiguration=DEFAULT\n"
                  f'targetConfiguration="{opt_config_str}"')
        config_file.write_text(config)
        # Write config to validation directory, executing in its own tmp dir
        conf_valid = config.replace(
            f"execdir = {self.tmp_dir.absolute()}\n",
            f"execdir = {self.validation_dir_tmp.absolute()}\n")
        (self.validation_dir / config_file.name).write_text(conf_valid)

    def create_instance_file(self: AblationScenario, test: bool = False) -> None:
        """Create an instance file for ablation analysis.

        Writes one absolute instance path per line to
        instances_train.txt (or instances_test.txt when ``test`` is True) in the
        scenario directory, and copies it to the validation directory.

        Args:
            test: Write the test instance file instead of the training one.
                Falls back to the training set when no test set is given.
        """
        file_suffix = "_train.txt"
        instance_set = self.train_set
        if test:
            file_suffix = "_test.txt"
            instance_set = self.test_set if self.test_set is not None else self.train_set
        # We give the Ablation script the paths of the instances
        file_instance = self.scenario_dir / f"instances{file_suffix}"
        with file_instance.open("w") as fh:
            for instance in instance_set._instance_paths:
                # Multi-file instances are written as space-joined paths
                if isinstance(instance, list):
                    joined_instances = " ".join(
                        [str(file.absolute()) for file in instance])
                    fh.write(f"{joined_instances}\n")
                else:
                    fh.write(f"{instance.absolute()}\n")
        # Copy to validation directory
        shutil.copyfile(file_instance, self.validation_dir / file_instance.name)

    def check_for_ablation(self: AblationScenario) -> bool:
        """Check whether the ablation validation has terminated successfully.

        Returns:
            True if the table file exists and its first line reads
            "Ablation analysis validation complete.", otherwise False.
        """
        if not self.table_file.is_file():
            return False
        with self.table_file.open() as table:
            first_line = table.readline().strip()
        return first_line == "Ablation analysis validation complete."

    def read_ablation_table(self: AblationScenario) -> list[list[str]]:
        """Read the ablation table of this scenario.

        Returns:
            A nested list whose first row is the header. Every further row holds
            the round, flipped parameter(s), source value(s), target value(s)
            and validation result as strings. Lines that do not split into
            exactly five fields are skipped. Empty if the ablation validation
            has not completed successfully.
        """
        if not self.check_for_ablation():
            # No ablation table exists for this solver-instance pair
            return []
        results = [["Round", "Flipped parameter", "Source value", "Target value",
                    "Validation result"]]

        with self.table_file.open() as table:
            lines = table.readlines()
        for line in lines:
            # Pre-process lines from the ablation file and add to the results list.
            # Sometimes ablation rounds switch multiple parameters at once.
            # EXAMPLE: 2 EDR, EDRalpha 0, 0.1 1, 0.1013241633106732 486.31691
            # To split the row correctly, we remove the space before the comma
            # separated parameters and add it back.
            values = re.sub(r"\s+", " ", line.strip())
            values = re.sub(r", ", ",", values)
            values = [val.replace(",", ", ") for val in values.split(" ")]
            if len(values) == 5:
                results.append(values)
        return results

    def submit_ablation(self: AblationScenario,
                        run_on: Runner = Runner.SLURM) -> list[Run]:
        """Submit an ablation job, followed by a validation job if a test set exists.

        Args:
            run_on: Determines to which RunRunner queue the job is added

        Returns:
            A list with the ablation Run and, when a test set is given, the
            validation Run (which depends on the ablation Run). When running
            locally, each run is waited on before this method returns.

        Raises:
            ValueError: If an executable needed for the submission was not
                given to this scenario.
        """
        # Fail before submitting anything, rather than after the ablation job
        # has already been queued.
        if self.ablation_exec is None:
            raise ValueError("No ablation executable was given to this scenario.")
        if self.test_set is not None and self.ablation_validation_exec is None:
            raise ValueError(
                "No ablation validation executable was given to this scenario.")

        # 1. submit the ablation to the runrunner queue
        clis = gv.settings().get_slurm_max_parallel_runs_per_node()
        cmd = f"{self.ablation_exec.absolute()} --optionFile ablation_config.txt"
        srun_options = ["-N1", "-n1", f"-c{clis}"]
        sbatch_options = ([f"--cpus-per-task={clis}"]
                          + gv.settings().get_slurm_extra_options(as_args=True))

        run_ablation = rrr.add_to_queue(
            runner=run_on,
            cmd=cmd,
            name=CommandName.RUN_ABLATION,
            base_dir=sl.caller_log_dir,
            path=self.scenario_dir,
            sbatch_options=sbatch_options,
            srun_options=srun_options)

        runs = []
        if run_on == Runner.LOCAL:
            run_ablation.wait()
        runs.append(run_ablation)

        # 2. Run ablation validation run if we have a test set to run on
        if self.test_set is not None:
            # Validation dir should have a copy of all needed files, except for the
            # output of the ablation run, which is stored in ablation-run[seed].txt
            cmd = (f"{self.ablation_validation_exec.absolute()} "
                   "--optionFile ablation_config.txt "
                   f"--ablationLogFile ../log/ablation-run{self._SEED}.txt")

            run_ablation_validation = rrr.add_to_queue(
                runner=run_on,
                cmd=cmd,
                name=CommandName.RUN_ABLATION_VALIDATION,
                path=self.validation_dir,
                base_dir=sl.caller_log_dir,
                dependencies=run_ablation,
                sbatch_options=sbatch_options,
                srun_options=srun_options)

            if run_on == Runner.LOCAL:
                run_ablation_validation.wait()
            runs.append(run_ablation_validation)

        return runs