Coverage for sparkle/solver/solver.py: 85%

155 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-05 14:48 +0000

1"""File to handle a solver and its directories.""" 

2 

3from __future__ import annotations 

4import sys 

5from typing import Any 

6import shlex 

7import ast 

8import json 

9from pathlib import Path 

10 

11import runrunner as rrr 

12from runrunner.local import LocalRun 

13from runrunner.slurm import SlurmRun 

14from runrunner.base import Status, Runner 

15 

16from sparkle.tools import pcsparser, RunSolver 

17from sparkle.types import SparkleCallable, SolverStatus 

18from sparkle.solver.verifier import SolutionVerifier 

19from sparkle.instance import InstanceSet 

20from sparkle.types import resolve_objective, SparkleObjective, UseTime 

21 

22 

class Solver(SparkleCallable):
    """Class to handle a solver and its directories.

    Attributes:
        meta_data: File name of the optional solver meta data file, looked up
            inside the solver directory.
        wrapper: File name of the wrapper script, inside the solver directory,
            that is invoked to run the solver.
    """
    meta_data = "solver_meta.txt"
    wrapper = "sparkle_solver_wrapper.py"

    def __init__(self: Solver,
                 directory: Path,
                 raw_output_directory: Path = None,
                 runsolver_exec: Path = None,
                 deterministic: bool = None,
                 verifier: SolutionVerifier = None) -> None:
        """Initialize solver.

        Args:
            directory: Directory of the solver.
            raw_output_directory: Directory where solver will write its raw output.
            runsolver_exec: Path to the runsolver executable.
                By default, runsolver in directory.
            deterministic: Bool indicating determinism of the algorithm.
                If None, the value is read from the meta data file when that
                file exists, and defaults to False otherwise.
            verifier: The solution verifier to use. If None, no verifier is used.
        """
        super().__init__(directory, runsolver_exec, raw_output_directory)
        self.deterministic = deterministic
        self.verifier = verifier
        self.meta_data_file = self.directory / Solver.meta_data

        if self.runsolver_exec is None:
            self.runsolver_exec = self.directory / "runsolver"
        if not self.meta_data_file.exists():
            self.meta_data_file = None
        if self.deterministic is None:
            if self.meta_data_file is not None:
                # read_text opens AND closes the file, so no handle is leaked
                meta_dict = ast.literal_eval(self.meta_data_file.read_text())
                # A meta file without the key falls back to the documented default
                self.deterministic = meta_dict.get("deterministic", False)
            else:
                self.deterministic = False

    def _get_pcs_file(self: Solver, port_type: str = None) -> Path | bool:
        """Get path of the parameter file.

        Args:
            port_type: If given, only PCS files whose file name contains this
                string are considered.

        Returns:
            Path to the parameter file or False if the parameter file does not exist.
        """
        pcs_files = [p for p in self.directory.iterdir() if p.suffix == ".pcs"
                     and (port_type is None or port_type in p.name)]
        if not pcs_files:
            return False
        # Generated (ported) PCS files get a suffix appended to the original
        # name, so the shortest name is taken to be the original. The name
        # itself breaks ties so the result does not depend on directory order.
        return min(pcs_files, key=lambda p: (len(p.name), p.name))

    def get_pcs_file(self: Solver, port_type: str = None) -> Path | None:
        """Get path of the parameter file.

        Args:
            port_type: If given, only PCS files whose file name contains this
                string are considered.

        Returns:
            Path to the parameter file. None if it can not be resolved.
        """
        return self._get_pcs_file(port_type) or None

    def read_pcs_file(self: Solver) -> bool:
        """Checks if the pcs file can be read.

        Returns:
            True if a PCS file exists and the parser loads it without raising
            a SyntaxError, False otherwise.
        """
        pcs_file = self._get_pcs_file()
        if not pcs_file:
            # No file to read; do not hand the literal "False" to the parser
            return False
        try:
            parser = pcsparser.PCSParser()
            parser.load(str(pcs_file), convention="smac")
        except SyntaxError:
            return False
        return True

    def get_pcs(self: Solver) -> list[dict[str, Any]] | None:
        """Get the parameter content of the PCS file.

        Returns:
            The entries of the parsed PCS file whose "type" is "parameter",
            or None if the solver has no PCS file.
        """
        if not (pcs_file := self.get_pcs_file()):
            return None
        parser = pcsparser.PCSParser()
        parser.load(str(pcs_file), convention="smac")
        return [p for p in parser.pcs.params if p["type"] == "parameter"]

    def port_pcs(self: Solver, port_type: pcsparser.PCSConvention) -> None:
        """Port the parameter file to the given port type.

        The ported file is written next to the original as
        "<original stem>_<port_type>.pcs". An already existing target file is
        left untouched.

        Args:
            port_type: The PCS convention to export to.

        Raises:
            FileNotFoundError: If the solver has no PCS file to port.
        """
        pcs_file = self.get_pcs_file()
        if pcs_file is None:
            raise FileNotFoundError(
                f"Solver {self.name} has no PCS file in {self.directory} to port.")
        target_pcs_file = pcs_file.parent / f"{pcs_file.stem}_{port_type}.pcs"
        if target_pcs_file.exists():  # Already exists, possibly user defined
            return
        # Only parse the source file when there is actually something to export
        parser = pcsparser.PCSParser()
        parser.load(str(pcs_file), convention="smac")
        parser.export(convention=port_type,
                      destination=target_pcs_file)

    def get_forbidden(self: Solver, port_type: pcsparser.PCSConvention) -> Path | None:
        """Get the path to the file containing forbidden parameter combinations.

        Args:
            port_type: The PCS convention; only "IRACE" is looked up.

        Returns:
            Path to a file in the solver directory whose name ends with
            "forbidden.txt" (the first in sorted order), or None if there is no
            such file or port_type is not "IRACE".
        """
        if port_type == "IRACE":
            # Sorted so the choice is stable when several candidates exist
            forbidden = sorted(p for p in self.directory.iterdir()
                               if p.name.endswith("forbidden.txt"))
            if forbidden:
                return forbidden[0]
        return None

    def build_cmd(self: Solver,
                  instance: str | list[str],
                  objectives: list[SparkleObjective],
                  seed: int,
                  cutoff_time: int = None,
                  configuration: dict = None,
                  log_dir: Path = None) -> list[str]:
        """Build the solver call on an instance with a configuration.

        Args:
            instance: Path to the instance, list of paths for multi-file instances.
            objectives: Objectives of the run, passed to the wrapper as a
                comma separated string.
            seed: Seed of the solver.
            cutoff_time: Cutoff time for the solver. If None, the command is
                not wrapped with RunSolver and sys.maxsize is passed to the
                wrapper as cutoff time.
            configuration: Configuration of the solver. Not modified.
            log_dir: Directory for the RunSolver log files. Defaults to the
                current working directory.

        Returns:
            List of commands and arguments to execute the solver.
        """
        # Work on a copy: the caller may reuse its dictionary for other instances
        configuration = {} if configuration is None else dict(configuration)
        # Ensure configuration contains required entries for each wrapper
        configuration["solver_dir"] = str(self.directory.absolute())
        configuration["instance"] = instance
        configuration["seed"] = seed
        configuration["objectives"] = ",".join([str(obj) for obj in objectives])
        configuration["cutoff_time"] =\
            cutoff_time if cutoff_time is not None else sys.maxsize
        # Ensure stringification of dictionary will go correctly for key value pairs
        configuration = {key: str(value) for key, value in configuration.items()}
        # shlex.quote yields the same single-quoted argument as before, but also
        # escapes single quotes that may occur inside values (e.g. in paths)
        solver_cmd = [str(self.directory / Solver.wrapper),
                      shlex.quote(json.dumps(configuration))]
        if cutoff_time is None:
            return solver_cmd
        # Use RunSolver
        if log_dir is None:
            log_dir = Path()
        # A multi-file instance is named after its first file
        first_file = instance[0] if isinstance(instance, list) else instance
        log_name_base = f"{Path(first_file).name}_{self.name}"
        return RunSolver.wrap_command(self.runsolver_exec,
                                      solver_cmd,
                                      cutoff_time,
                                      log_dir,
                                      log_name_base=log_name_base)

    def run(self: Solver,
            instance: str | list[str] | InstanceSet,
            objectives: list[SparkleObjective],
            seed: int,
            cutoff_time: int = None,
            configuration: dict = None,
            run_on: Runner = Runner.LOCAL,
            commandname: str = "run_solver",
            sbatch_options: list[str] = None,
            log_dir: Path = None) -> SlurmRun | list[dict[str, Any]] | dict[str, Any]:
        """Run the solver on an instance with a certain configuration.

        Args:
            instance: The instance(s) to run the solver on, list in case of multi-file.
                In case of an instance set, will run on all instances in the set.
            objectives: The objectives of the run, passed on to the solver wrapper.
            seed: Seed to run the solver with. Fill with arbitrary int in case of
                deterministic solver.
            cutoff_time: The cutoff time for the solver, measured through RunSolver.
                If None, will be executed without RunSolver.
            configuration: The solver configuration to use. Can be empty.
            run_on: The runner to submit the job(s) to.
            commandname: Name given to the submitted job(s).
            sbatch_options: Options forwarded to the runner for Slurm jobs.
            log_dir: Path where to place output files. Defaults to
                self.raw_output_directory.

        Returns:
            For local runs: the solver output dict (possibly with runsolver
            values) when exactly one job was run, otherwise a list with one
            dict per job. A dict holding only an ERROR status if the local run
            failed. For non-local runs the run object itself is returned.
        """
        if log_dir is None:
            log_dir = self.raw_output_directory
        if isinstance(instance, InstanceSet):
            instances = [inst.absolute() for inst in instance.instance_paths]
        else:
            instances = [instance]
        cmds = [" ".join(self.build_cmd(inst,
                                        objectives=objectives,
                                        seed=seed,
                                        cutoff_time=cutoff_time,
                                        configuration=configuration,
                                        log_dir=log_dir))
                for inst in instances]
        queued_run = rrr.add_to_queue(runner=run_on,
                                      cmd=cmds,
                                      name=commandname,
                                      base_dir=log_dir,
                                      sbatch_options=sbatch_options)

        if not isinstance(queued_run, LocalRun):
            return queued_run

        queued_run.wait()
        # Subprocess resulted in error
        if queued_run.status == Status.ERROR:
            print(f"WARNING: Solver {self.name} execution seems to have failed!\n")
            for i, (cmd, job) in enumerate(zip(cmds, queued_run.jobs)):
                # Report the streams of the job itself, not those of the first job
                print(f"[Job {i}] The used command was: {cmd}\n"
                      "The error yielded was:\n"
                      f"\t-stdout: '{job._process.stdout}'\n"
                      f"\t-stderr: '{job._process.stderr}'\n")
            return {"status": SolverStatus.ERROR, }

        solver_outputs = []
        for inst, cmd, job in zip(instances, cmds, queued_run.jobs):
            solver_cmd = cmd.split(" ")
            runsolver_configuration = None
            if solver_cmd[0] == str(self.runsolver_exec.absolute()):
                # NOTE(review): presumably the RunSolver call and its arguments
                # take up the first 11 entries, ending with the raw output
                # path; confirm against RunSolver.wrap_command
                runsolver_configuration = solver_cmd[:11]
            solver_output = Solver.parse_solver_output(job.stdout,
                                                       runsolver_configuration)
            # Verification reads the last RunSolver argument as a path, so it is
            # only possible when the command was wrapped with RunSolver
            if self.verifier is not None and runsolver_configuration is not None:
                # Pass the instance this job ran on rather than the whole set
                solver_output["status"] = self.verifier.verifiy(
                    inst, Path(runsolver_configuration[-1]))
            solver_outputs.append(solver_output)
        if len(solver_outputs) == 1:
            return solver_outputs[0]
        return solver_outputs

    @staticmethod
    def config_str_to_dict(config_str: str) -> dict[str, str]:
        """Parse a configuration string to a dictionary.

        The string is expected to hold alternating parameter names and values,
        e.g. "-alpha '0.5' -beta '-2'". Leading dashes are removed from the
        names only, so negative or hyphenated values stay intact.

        Args:
            config_str: The configuration string to parse.

        Returns:
            Dictionary mapping parameter names to their (string) values. Empty
            for empty input, "{}" or an uneven number of tokens.
        """
        config_str = config_str.strip()
        if config_str in ("", r"{}"):
            return {}
        # Split the string by spaces, but conserve quoted substrings
        config_list = shlex.split(config_str)
        # We return empty for uneven input
        if len(config_list) % 2:
            return {}
        config_dict = {}
        for name, value in zip(config_list[::2], config_list[1::2]):
            # As the value will already be a string object, no quotes are allowed in it
            config_dict[name.lstrip("-")] = value.strip('"').strip("'")
        return config_dict

    @staticmethod
    def parse_solver_output(
            solver_output: str,
            runsolver_configuration: list[str | Path] = None) -> dict[str, Any]:
        """Parse the output of the solver.

        Args:
            solver_output: The output of the solver run which needs to be parsed
            runsolver_configuration: The runsolver configuration to wrap the solver
                with. If runsolver was not used this should be None.

        Returns:
            Dictionary representing the parsed solver output. The "status"
            entry is cast to SolverStatus, entries that resolve to an objective
            are post processed, and "cutoff_time" is removed.
        """
        if runsolver_configuration is not None:
            parsed_output = RunSolver.get_solver_output(runsolver_configuration,
                                                        solver_output)
        else:
            parsed_output = ast.literal_eval(solver_output)
        # cast status attribute from str to Enum
        parsed_output["status"] = SolverStatus(parsed_output["status"])
        # apply objectives to parsed output, runtime based objectives added here
        # Iterate over a snapshot, as entries are rewritten along the way
        for key, value in list(parsed_output.items()):
            if key == "status":
                continue
            objective = resolve_objective(key)
            if objective is None:
                continue
            if objective.use_time == UseTime.NO:
                if objective.post_process is not None:
                    # Store under the original key, not under the objective object
                    parsed_output[key] = objective.post_process(value)
            else:
                # Time based objectives need the RunSolver measurements
                if runsolver_configuration is None:
                    continue
                if objective.use_time == UseTime.CPU_TIME:
                    parsed_output[key] = parsed_output["cpu_time"]
                else:
                    parsed_output[key] = parsed_output["wall_time"]
                if objective.post_process is not None:
                    parsed_output[key] = objective.post_process(
                        parsed_output[key], parsed_output["cutoff_time"])
        parsed_output.pop("cutoff_time", None)
        return parsed_output

308 return parsed_output