Coverage for sparkle/configurator/ablation.py: 91%

111 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-13 10:34 +0000

1#!/usr/bin/env python3 

2# -*- coding: UTF-8 -*- 

3"""Helper functions for ablation analysis.""" 

4from __future__ import annotations 

5import re 

6import shutil 

7import decimal 

8from pathlib import Path 

9 

10import runrunner as rrr 

11from runrunner.base import Runner, Run 

12 

13from sparkle.configurator import ConfigurationScenario 

14from sparkle.instance import InstanceSet 

15 

16 

class AblationScenario:
    """Class for ablation analysis.

    Holds the directory layout of one ablation scenario and offers methods to
    write the configuration and instance files, submit the ablation jobs and
    read back the resulting ablation table.
    """

    # We use the SMAC2 target algorithm for solver output handling
    configurator_target = Path(__file__).parent.parent.resolve() /\
        "Components" / "smac2-v2.10.03-master-778" / "smac2_target_algorithm.py"

    ablation_dir = Path(__file__).parent.parent / "Components" /\
        "ablationAnalysis-0.9.4"
    ablation_executable = ablation_dir / "ablationAnalysis"
    ablation_validation_executable = ablation_dir / "ablationValidation"

    def __init__(self: AblationScenario,
                 configuration_scenario: ConfigurationScenario,
                 test_set: InstanceSet,
                 output_dir: Path,
                 override_dirs: bool = False) -> None:
        """Initialize ablation scenario.

        Args:
            configuration_scenario: Configuration scenario. Its solver, instance
                set (used as training set) and name are taken over.
            test_set: The test instance set. May be None, in which case the
                scenario name is not extended with a test set name.
            output_dir: The output directory in which the scenario directory
                is created
            override_dirs: Whether to clean the scenario directory if it already exists
        """
        self.config_scenario = configuration_scenario
        self.solver = configuration_scenario.solver
        self.train_set = configuration_scenario.instance_set
        # Set by create_configuration_file, read by submit_ablation
        self.concurrent_clis = None
        self.test_set = test_set
        self.output_dir = output_dir
        self.scenario_name = configuration_scenario.name
        if self.test_set is not None:
            self.scenario_name += f"_{self.test_set.name}"
        self.scenario_dir = self.output_dir / self.scenario_name
        if override_dirs and self.scenario_dir.exists():
            print("Warning: found existing ablation scenario. This will be removed.")
            shutil.rmtree(self.scenario_dir)

        # Create required scenario directories
        self.tmp_dir = self.scenario_dir / "tmp"
        self.tmp_dir.mkdir(parents=True, exist_ok=True)

        self.validation_dir = self.scenario_dir / "validation"
        self.validation_dir_tmp = self.validation_dir / "tmp"
        self.validation_dir_tmp.mkdir(parents=True, exist_ok=True)
        # The seed (1234) in the file name matches the seed written to the config
        self.table_file = self.validation_dir / "log" / "ablation-validation-run1234.txt"

    @staticmethod
    def _format_parameter_value(value: object) -> str:
        """Render a parameter value as text without E scientific notation.

        Values that contain no "e" or that cannot be parsed as a float
        (e.g. categorical values such as "true") are returned unchanged.
        Surrounding single quotes are preserved.
        """
        text = str(value)
        if "e" not in text.lower():
            return text
        unquoted = text.strip("'")
        try:
            number = float(unquoted)
        except ValueError:
            # Not a number, merely a value containing the letter "e"
            return text
        ctx = decimal.Context(prec=16)
        return text.replace(unquoted, format(ctx.create_decimal(number), "f"))

    def create_configuration_file(self: AblationScenario,
                                  cutoff_time: int,
                                  cutoff_length: str,
                                  concurrent_clis: int,
                                  best_configuration: dict,
                                  ablation_racing: bool = False) -> Path:
        """Create a configuration file for ablation analysis.

        The file is written to the scenario directory and, with the execution
        directory adapted, to the validation directory.

        Args:
            cutoff_time: The cutoff time for ablation analysis
            cutoff_length: The cutoff length for ablation analysis
            concurrent_clis: The maximum number of concurrent jobs on a single node
            best_configuration: The target configuration as parameter name to
                value mapping. Keys that are not in the solver's parameter
                space are ignored, missing parameters are supplemented with
                their default value. The dictionary itself is not modified.
            ablation_racing: Value written to the useRacing option

        Returns:
            Path to the configuration file in the validation directory.
        """
        self.concurrent_clis = concurrent_clis
        objective = self.config_scenario.sparkle_objective
        parameters = list(self.solver.get_cs().values())
        parameter_names = {p.name for p in parameters}
        # We need to drop any redundant keys that are not in PCS. This is done on
        # a copy, so the dictionary of the caller stays untouched.
        target_configuration = {key: value
                                for key, value in best_configuration.items()
                                if key in parameter_names}
        # We need to check which params are missing and supplement with default
        # values. Checking the keys (instead of the rendered string) avoids false
        # matches for names that are a substring of another name or value.
        for p in parameters:
            if p.name not in target_configuration:
                target_configuration[p.name] = p.default_value

        # Ablation cannot deal with E scientific notation in floats, so every
        # value is formatted on its own before being joined.
        opt_config_str = " ".join(
            f"-{name} {self._format_parameter_value(value)}"
            for name, value in target_configuration.items())

        smac_run_obj = "RUNTIME" if objective.time else "QUALITY"
        objective_str = "MEAN10" if objective.time else "MEAN"
        pcs_file_path = f"{self.config_scenario.solver.pcs_file.absolute()}"

        # Create config file
        config_file = self.scenario_dir / "ablation_config.txt"
        config = (f'algo = "{AblationScenario.configurator_target.absolute()} '
                  f"{self.config_scenario.solver.directory.absolute()} "
                  f'{self.tmp_dir.absolute()} {objective}"\n'
                  f"execdir = {self.tmp_dir.absolute()}\n"
                  "experimentDir = ./\n"
                  f"deterministic = {1 if self.solver.deterministic else 0}\n"
                  f"run_obj = {smac_run_obj}\n"
                  f"overall_obj = {objective_str}\n"
                  f"cutoffTime = {cutoff_time}\n"
                  f"cutoff_length = {cutoff_length}\n"
                  f"cli-cores = {self.concurrent_clis}\n"
                  f"useRacing = {ablation_racing}\n"
                  "seed = 1234\n"
                  f"paramfile = {pcs_file_path}\n"
                  "instance_file = instances_train.txt\n"
                  "test_instance_file = instances_test.txt\n"
                  "sourceConfiguration=DEFAULT\n"
                  f'targetConfiguration="{opt_config_str}"')
        # write_text closes the file handle, unlike a bare open().write()
        config_file.write_text(config)
        # Write config to validation directory
        conf_valid = config.replace(f"execdir = {self.tmp_dir.absolute()}\n",
                                    f"execdir = {self.validation_dir_tmp.absolute()}\n")
        validation_config_file = self.validation_dir / config_file.name
        validation_config_file.write_text(conf_valid)
        return validation_config_file

    def create_instance_file(self: AblationScenario, test: bool = False) -> Path:
        """Create an instance file for ablation analysis.

        Args:
            test: If True, write the test instance file. Falls back to the
                training set when no test set is available.

        Returns:
            Path to the instance file in the scenario directory. A copy is
            placed in the validation directory.
        """
        file_suffix = "_train.txt"
        instance_set = self.train_set
        if test:
            file_suffix = "_test.txt"
            instance_set = self.test_set if self.test_set is not None else self.train_set
        # We give the Ablation script the paths of the instances
        file_instance = self.scenario_dir / f"instances{file_suffix}"
        with file_instance.open("w") as fh:
            for instance in instance_set._instance_paths:
                if isinstance(instance, list):
                    # Multi file instances are written space separated on one line
                    joined_instances = " ".join(
                        [str(file.absolute()) for file in instance])
                    fh.write(f"{joined_instances}\n")
                else:
                    fh.write(f"{instance.absolute()}\n")
        # Copy to validation directory
        shutil.copyfile(file_instance, self.validation_dir / file_instance.name)
        return file_instance

    def check_for_ablation(self: AblationScenario) -> bool:
        """Checks if ablation has terminated successfully."""
        if not self.table_file.is_file():
            return False
        # First line in the table file should be "Ablation analysis validation complete."
        with self.table_file.open() as fh:
            table_line = fh.readline().strip()
        return table_line == "Ablation analysis validation complete."

    def read_ablation_table(self: AblationScenario) -> list[list[str]]:
        """Read from ablation table of a scenario.

        Returns:
            A nested list with a header row followed by one row of five values
            per ablation round. Empty if ablation has not completed.
        """
        if not self.check_for_ablation():
            # No ablation table exists for this solver-instance pair
            return []
        results = [["Round", "Flipped parameter", "Source value", "Target value",
                    "Validation result"]]

        with self.table_file.open() as fh:
            for line in fh:
                # Pre-process lines from the ablation file and add to the results.
                # Sometimes ablation rounds switch multiple parameters at once.
                # EXAMPLE: 2 EDR, EDRalpha 0, 0.1 1, 0.1013241633106732 486.31691
                # To split the row correctly, we remove the space after the comma
                # of comma separated parameters and add it back after splitting.
                values = re.sub(r"\s+", " ", line.strip())
                values = re.sub(r", ", ",", values)
                values = [val.replace(",", ", ") for val in values.split(" ")]
                # Lines that are not table rows (e.g. the status line) are skipped
                if len(values) == 5:
                    results.append(values)
        return results

    def submit_ablation(self: AblationScenario,
                        log_dir: Path,
                        sbatch_options: list[str] = None,
                        slurm_prepend: str | list[str] | Path = None,
                        run_on: Runner = Runner.SLURM) -> list[Run]:
        """Submit an ablation job.

        Uses the number of concurrent CLIs stored by create_configuration_file.

        Args:
            log_dir: Directory to store job logs
            sbatch_options: Options to pass to sbatch. The list is not modified.
            slurm_prepend: Script to prepend to sbatch script
            run_on: Determines to which RunRunner queue the job is added

        Returns:
            A list of Run objects: the ablation run, followed by the validation
            run if a test set is present. Local runs are waited for before
            they are added.
        """
        # Copy the options: extending the argument in place would accumulate
        # duplicates in the list of the caller (or in a shared default).
        sbatch_options = list(sbatch_options) if sbatch_options is not None else []
        sbatch_options.append(f"--cpus-per-task={self.concurrent_clis}")

        # 1. submit the ablation to the runrunner queue
        cmd = (f"{AblationScenario.ablation_executable.absolute()} "
               "--optionFile ablation_config.txt")
        srun_options = ["-N1", "-n1", f"-c{self.concurrent_clis}"]
        run_ablation = rrr.add_to_queue(
            runner=run_on,
            cmd=cmd,
            name=f"Ablation analysis: {self.solver.name} on {self.train_set.name}",
            base_dir=log_dir,
            path=self.scenario_dir,
            sbatch_options=sbatch_options,
            srun_options=srun_options,
            prepend=slurm_prepend)

        runs = []
        if run_on == Runner.LOCAL:
            run_ablation.wait()
        runs.append(run_ablation)

        # 2. Run ablation validation run if we have a test set to run on
        if self.test_set is not None:
            # Validation dir should have a copy of all needed files, except for the
            # output of the ablation run, which is stored in ablation-run[seed].txt
            cmd = f"{AblationScenario.ablation_validation_executable.absolute()} "\
                  "--optionFile ablation_config.txt "\
                  "--ablationLogFile ../log/ablation-run1234.txt"

            run_ablation_validation = rrr.add_to_queue(
                runner=run_on,
                cmd=cmd,
                name=f"Ablation validation: Test set {self.test_set.name}",
                path=self.validation_dir,
                base_dir=log_dir,
                dependencies=run_ablation,
                sbatch_options=sbatch_options,
                prepend=slurm_prepend)

            if run_on == Runner.LOCAL:
                run_ablation_validation.wait()
            runs.append(run_ablation_validation)
        return runs