Coverage for sparkle/solver/solver.py: 85%
155 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 14:48 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 14:48 +0000
1"""File to handle a solver and its directories."""
3from __future__ import annotations
4import sys
5from typing import Any
6import shlex
7import ast
8import json
9from pathlib import Path
11import runrunner as rrr
12from runrunner.local import LocalRun
13from runrunner.slurm import SlurmRun
14from runrunner.base import Status, Runner
16from sparkle.tools import pcsparser, RunSolver
17from sparkle.types import SparkleCallable, SolverStatus
18from sparkle.solver.verifier import SolutionVerifier
19from sparkle.instance import InstanceSet
20from sparkle.types import resolve_objective, SparkleObjective, UseTime
class Solver(SparkleCallable):
    """Class to handle a solver and its directories."""
    # File name (inside the solver directory) holding solver meta data as a
    # Python dict literal, e.g. the "deterministic" flag.
    meta_data = "solver_meta.txt"
    # File name (inside the solver directory) of the wrapper script that is
    # executed for every solver call.
    wrapper = "sparkle_solver_wrapper.py"

    def __init__(self: Solver,
                 directory: Path,
                 raw_output_directory: Path = None,
                 runsolver_exec: Path = None,
                 deterministic: bool = None,
                 verifier: SolutionVerifier = None) -> None:
        """Initialize solver.

        Args:
            directory: Directory of the solver.
            raw_output_directory: Directory where solver will write its raw output.
            runsolver_exec: Path to the runsolver executable.
                By default, runsolver in directory.
            deterministic: Bool indicating determinism of the algorithm.
                If None, the value is read from the "deterministic" entry of
                the meta data file in the solver directory when that file
                exists, and defaults to False otherwise.
            verifier: The solution verifier to use. If None, no verifier is used.
        """
        super().__init__(directory, runsolver_exec, raw_output_directory)
        self.deterministic = deterministic
        self.verifier = verifier
        self.meta_data_file = self.directory / Solver.meta_data

        if self.runsolver_exec is None:
            self.runsolver_exec = self.directory / "runsolver"
        if not self.meta_data_file.exists():
            self.meta_data_file = None
        if self.deterministic is None:
            if self.meta_data_file is not None:
                # read_text() opens and closes the file, so no handle is leaked
                meta_dict = ast.literal_eval(self.meta_data_file.read_text())
                self.deterministic = meta_dict["deterministic"]
            else:
                self.deterministic = False

    def _get_pcs_file(self: Solver, port_type: str = None) -> Path | bool:
        """Get path of the parameter file.

        Args:
            port_type: If given, only .pcs files whose name contains this
                string are considered.

        Returns:
            Path to the parameter file or False if the parameter file does not exist.
        """
        pcs_files = [p for p in self.directory.iterdir() if p.suffix == ".pcs"
                     and (port_type is None or port_type in p.name)]
        if not pcs_files:
            return False
        # Ported files are written as "<stem>_<port_type>.pcs" (see port_pcs),
        # so the shortest name is taken to be the original file.
        return min(pcs_files, key=lambda p: len(p.name))

    def get_pcs_file(self: Solver, port_type: str = None) -> Path | None:
        """Get path of the parameter file.

        Args:
            port_type: If given, only .pcs files whose name contains this
                string are considered.

        Returns:
            Path to the parameter file. None if it can not be resolved.
        """
        return self._get_pcs_file(port_type) or None

    def read_pcs_file(self: Solver) -> bool:
        """Check if the pcs file can be read.

        Returns:
            True if a PCS file exists and parses in the SMAC convention,
            False if there is no PCS file or parsing raises a SyntaxError.
        """
        pcs_file = self._get_pcs_file()
        if not pcs_file:
            return False
        parser = pcsparser.PCSParser()
        try:
            parser.load(str(pcs_file), convention="smac")
        except SyntaxError:
            return False
        return True

    def get_pcs(self: Solver) -> list[dict[str, Any]] | None:
        """Get the parameter content of the PCS file.

        Returns:
            The entries of the parsed PCS whose "type" is "parameter", or None
            when the solver has no PCS file.
        """
        if not (pcs_file := self.get_pcs_file()):
            return None
        parser = pcsparser.PCSParser()
        parser.load(str(pcs_file), convention="smac")
        return [p for p in parser.pcs.params if p["type"] == "parameter"]

    def port_pcs(self: Solver, port_type: pcsparser.PCSConvention) -> None:
        """Port the parameter file to the given port type.

        The ported file is written next to the original as
        "<stem>_<port_type>.pcs". Nothing happens when the solver has no PCS
        file or when the target file already exists.
        """
        pcs_file = self.get_pcs_file()
        if pcs_file is None:
            return  # Nothing to port
        target_pcs_file = pcs_file.parent / f"{pcs_file.stem}_{port_type}.pcs"
        if target_pcs_file.exists():  # Already exists, possibly user defined
            return
        parser = pcsparser.PCSParser()
        parser.load(str(pcs_file), convention="smac")
        parser.export(convention=port_type,
                      destination=target_pcs_file)

    def get_forbidden(self: Solver,
                      port_type: pcsparser.PCSConvention) -> Path | None:
        """Get the path to the file containing forbidden parameter combinations.

        Only looked up for the "IRACE" port type: returns the first file in the
        solver directory whose name ends with "forbidden.txt", else None.
        """
        if port_type == "IRACE":
            forbidden = [p for p in self.directory.iterdir()
                         if p.name.endswith("forbidden.txt")]
            if forbidden:
                return forbidden[0]
        return None

    @staticmethod
    def _instance_name(instance: str | Path | list[str]) -> str:
        """Return the file name used to label an instance in log file names."""
        if isinstance(instance, (list, tuple)):
            # Multi-file instance: label it by its first file
            instance = instance[0]
        return Path(instance).name

    def build_cmd(self: Solver,
                  instance: str | list[str],
                  objectives: list[SparkleObjective],
                  seed: int,
                  cutoff_time: int = None,
                  configuration: dict = None,
                  log_dir: Path = None) -> list[str]:
        """Build the solver call on an instance with a configuration.

        Args:
            instance: Path to the instance, or a list of paths for multi-file
                instances.
            objectives: Objectives, passed to the wrapper as a comma-separated
                string.
            seed: Seed of the solver.
            cutoff_time: Cutoff time for the solver. If given, the call is
                wrapped with RunSolver; if None, sys.maxsize is passed to the
                wrapper and RunSolver is not used.
            configuration: Configuration of the solver. The caller's dict is
                not modified.
            log_dir: Directory for the RunSolver log files. Defaults to Path().

        Returns:
            List of commands and arguments to execute the solver.
        """
        # Work on a copy: the wrapper entries below must not leak into the
        # caller's dict (run() passes the same dict for every instance).
        configuration = dict(configuration) if configuration is not None else {}
        # Ensure configuration contains required entries for each wrapper
        configuration["solver_dir"] = str(self.directory.absolute())
        configuration["instance"] = instance
        configuration["seed"] = seed
        configuration["objectives"] = ",".join(str(obj) for obj in objectives)
        configuration["cutoff_time"] = \
            cutoff_time if cutoff_time is not None else sys.maxsize
        # Ensure stringification of dictionary will go correctly for key value pairs
        configuration = {key: str(value) for key, value in configuration.items()}
        # shlex.quote yields the same '...' wrapping as before, but also stays
        # correct when a value (e.g. an instance path) contains a single quote.
        solver_cmd = [str(self.directory / Solver.wrapper),
                      shlex.quote(json.dumps(configuration))]
        if cutoff_time is None:
            return solver_cmd
        if log_dir is None:
            log_dir = Path()
        log_name_base = f"{Solver._instance_name(instance)}_{self.name}"
        return RunSolver.wrap_command(self.runsolver_exec,
                                      solver_cmd,
                                      cutoff_time,
                                      log_dir,
                                      log_name_base=log_name_base)

    def run(self: Solver,
            instance: str | list[str] | InstanceSet,
            objectives: list[SparkleObjective],
            seed: int,
            cutoff_time: int = None,
            configuration: dict = None,
            run_on: Runner = Runner.LOCAL,
            commandname: str = "run_solver",
            sbatch_options: list[str] = None,
            log_dir: Path = None) -> SlurmRun | list[dict[str, Any]] | dict[str, Any]:
        """Run the solver on an instance with a certain configuration.

        Args:
            instance: The instance(s) to run the solver on, list in case of multi-file.
                In case of an instance set, will run on all instances in the set.
            objectives: The objectives to pass on to the solver wrapper.
            seed: Seed to run the solver with. Fill with an arbitrary int in
                case of a deterministic solver.
            cutoff_time: The cutoff time for the solver, measured through RunSolver.
                If None, will be executed without RunSolver.
            configuration: The solver configuration to use. Can be empty.
            run_on: The runner to queue the commands on.
            commandname: Name given to the queued run.
            sbatch_options: Options passed on to the queue (used for Slurm).
            log_dir: Path where to place output files. Defaults to
                self.raw_output_directory.

        Returns:
            For a local run: the parsed solver output dict (possibly with
            runsolver values) when exactly one command ran, otherwise a list
            of such dicts; on a failed run, {"status": SolverStatus.ERROR}.
            For any other runner: the queued run object, without waiting.
        """
        if log_dir is None:
            log_dir = self.raw_output_directory
        if isinstance(instance, InstanceSet):
            instances = [inst.absolute() for inst in instance.instance_paths]
        else:
            instances = [instance]
        # Keep the argument lists: re-splitting the joined strings later would
        # break on paths that contain spaces.
        cmd_lists = [self.build_cmd(inst,
                                    objectives=objectives,
                                    seed=seed,
                                    cutoff_time=cutoff_time,
                                    configuration=configuration,
                                    log_dir=log_dir)
                     for inst in instances]
        cmds = [" ".join(cmd) for cmd in cmd_lists]
        run = rrr.add_to_queue(runner=run_on,
                               cmd=cmds,
                               name=commandname,
                               base_dir=log_dir,
                               sbatch_options=sbatch_options)

        if not isinstance(run, LocalRun):
            return run
        run.wait()
        # Subprocess resulted in error
        if run.status == Status.ERROR:
            print(f"WARNING: Solver {self.name} execution seems to have failed!\n")
            for i, job in enumerate(run.jobs):
                print(f"[Job {i}] The used command was: {cmds[i]}\n"
                      "The error yielded was:\n"
                      f"\t-stdout: '{job._process.stdout}'\n"
                      f"\t-stderr: '{job._process.stderr}'\n")
            return {"status": SolverStatus.ERROR, }

        runsolver_path = str(self.runsolver_exec.absolute())
        solver_outputs = []
        for inst, cmd, job in zip(instances, cmd_lists, run.jobs):
            runsolver_configuration = None
            if cmd[0] == runsolver_path:
                # The first 11 arguments are the RunSolver call; its last
                # entry is used below as the raw result file for the verifier.
                runsolver_configuration = cmd[:11]
            solver_output = Solver.parse_solver_output(job.stdout,
                                                       runsolver_configuration)
            # The verifier needs the raw output file, which only exists when
            # RunSolver was used; previously this crashed on None[-1].
            if self.verifier is not None and runsolver_configuration is not None:
                # NOTE(review): method name `verifiy` kept as in the original;
                # confirm it matches the SolutionVerifier API.
                solver_output["status"] = self.verifier.verifiy(
                    inst, Path(runsolver_configuration[-1]))
            solver_outputs.append(solver_output)
        return solver_outputs[0] if len(solver_outputs) == 1 else solver_outputs

    @staticmethod
    def config_str_to_dict(config_str: str) -> dict[str, str]:
        """Parse a configuration string to a dictionary.

        The string is read as alternating "-name value" pairs. Leading dashes
        are stripped from names only, so hyphens inside names and values
        (e.g. negative numbers or "1e-5") are preserved. Surrounding quotes
        are stripped from values.

        Returns:
            The parsed dict; empty for empty input, "{}" or an odd number of
            tokens.
        """
        config_str = config_str.strip()
        if config_str == "" or config_str == r"{}":
            return {}
        # Split the string by spaces, but conserve quoted substrings
        config_list = shlex.split(config_str)
        # We return empty for uneven input
        if len(config_list) % 2 == 1:
            return {}
        # As the value will already be a string object, no quotes are allowed in it
        return {name.lstrip("-"): value.strip('"').strip("'")
                for name, value in zip(config_list[::2], config_list[1::2])}

    @staticmethod
    def parse_solver_output(
            solver_output: str,
            runsolver_configuration: list[str | Path] = None) -> dict[str, Any]:
        """Parse the output of the solver.

        Args:
            solver_output: The output of the solver run which needs to be parsed
            runsolver_configuration: The runsolver configuration to wrap the solver
                with. If runsolver was not used this should be None.

        Returns:
            Dictionary representing the parsed solver output. "status" is
            converted to SolverStatus, objective values are post-processed,
            and any "cutoff_time" entry is removed.
        """
        if runsolver_configuration is not None:
            parsed_output = RunSolver.get_solver_output(runsolver_configuration,
                                                        solver_output)
        else:
            parsed_output = ast.literal_eval(solver_output)
        # cast status attribute from str to Enum
        parsed_output["status"] = SolverStatus(parsed_output["status"])
        # Apply objectives to parsed output, runtime based objectives added here.
        # Iterate over a snapshot because values are rewritten in the loop.
        for key, value in list(parsed_output.items()):
            if key == "status":
                continue
            objective = resolve_objective(key)
            if objective is None:
                continue
            if objective.use_time == UseTime.NO:
                if objective.post_process is not None:
                    # Store under the string key (not the objective object) so
                    # the processed value replaces the raw one.
                    parsed_output[key] = objective.post_process(value)
            else:
                if runsolver_configuration is None:
                    # Time measurements only come from RunSolver
                    continue
                if objective.use_time == UseTime.CPU_TIME:
                    parsed_output[key] = parsed_output["cpu_time"]
                else:
                    parsed_output[key] = parsed_output["wall_time"]
                if objective.post_process is not None:
                    parsed_output[key] = objective.post_process(
                        parsed_output[key], parsed_output["cutoff_time"])
        parsed_output.pop("cutoff_time", None)
        return parsed_output