Coverage for sparkle/solver/solver.py: 84%
198 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-03 10:42 +0000
1"""File to handle a solver and its directories."""
2from __future__ import annotations
3import sys
4from typing import Any
5import shlex
6import ast
7import json
8from pathlib import Path
10from ConfigSpace import ConfigurationSpace
12import runrunner as rrr
13from runrunner.local import LocalRun
14from runrunner.slurm import Run, SlurmRun
15from runrunner.base import Status, Runner
17from sparkle.tools.parameters import PCSConverter, PCSConvention
18from sparkle.tools import RunSolver
19from sparkle.types import SparkleCallable, SolverStatus
20from sparkle.solver import verifiers
21from sparkle.instance import InstanceSet
22from sparkle.structures import PerformanceDataFrame
23from sparkle.types import resolve_objective, SparkleObjective, UseTime
class Solver(SparkleCallable):
    """Class to handle a solver and its directories."""
    # File name of the optional meta data file inside the solver directory,
    # read by __init__ to fill in `deterministic` and `verifier` defaults
    meta_data = "solver_meta.txt"
    # File name of the wrapper script each solver directory must provide;
    # build_cmd invokes it with a JSON-encoded configuration argument
    wrapper = "sparkle_solver_wrapper.py"
    # CLI script used by run_performance_dataframe to run the solver as a job
    solver_cli = Path(__file__).parent / "solver_cli.py"
    def __init__(self: Solver,
                 directory: Path,
                 raw_output_directory: Path = None,
                 runsolver_exec: Path = None,
                 deterministic: bool = None,
                 verifier: verifiers.SolutionVerifier = None) -> None:
        """Initialize solver.

        Args:
            directory: Directory of the solver.
            raw_output_directory: Directory where solver will write its raw output.
            runsolver_exec: Path to the runsolver executable.
                By default, runsolver in directory.
            deterministic: Bool indicating determinism of the algorithm.
                Defaults to False.
            verifier: The solution verifier to use. If None, no verifier is used.
        """
        super().__init__(directory, runsolver_exec, raw_output_directory)
        self.deterministic = deterministic
        self.verifier = verifier
        # Lazily resolved and cached by the `pcs_file` property
        self._pcs_file: Path = None

        meta_data_file = self.directory / Solver.meta_data
        if self.runsolver_exec is None:
            # Default to the runsolver binary shipped inside the solver directory
            self.runsolver_exec = self.directory / "runsolver"
        if meta_data_file.exists():
            # The meta data file contains a Python literal dict
            meta_data = ast.literal_eval(meta_data_file.open().read())
            # We only override the deterministic and verifier from file if not set
            if self.deterministic is None:
                if ("deterministic" in meta_data
                        and meta_data["deterministic"] is not None):
                    self.deterministic = meta_data["deterministic"]
            if self.verifier is None and "verifier" in meta_data:
                if isinstance(meta_data["verifier"], tuple):  # File verifier
                    # Tuple form: (verifier mapping key, path passed to its ctor)
                    self.verifier = verifiers.mapping[meta_data["verifier"][0]](
                        Path(meta_data["verifier"][1])
                    )
                elif meta_data["verifier"] in verifiers.mapping:
                    self.verifier = verifiers.mapping[meta_data["verifier"]]
        if self.deterministic is None:  # Default to False
            self.deterministic = False
74 def __str__(self: Solver) -> str:
75 """Return the sting representation of the solver."""
76 return self.name
78 @property
79 def pcs_file(self: Solver) -> Path:
80 """Get path of the parameter file."""
81 if self._pcs_file is None:
82 files = sorted([p for p in self.directory.iterdir() if p.suffix == ".pcs"])
83 if len(files) == 0:
84 return None
85 self._pcs_file = files[0]
86 return self._pcs_file
88 def get_pcs_file(self: Solver, port_type: PCSConvention) -> Path:
89 """Get path of the parameter file of a specific convention.
91 Args:
92 port_type: Port type of the parameter file. If None, will return the
93 file with the shortest name.
95 Returns:
96 Path to the parameter file. None if it can not be resolved.
97 """
98 pcs_files = sorted([p for p in self.directory.iterdir() if p.suffix == ".pcs"])
99 if port_type is None:
100 return pcs_files[0]
101 for file in pcs_files:
102 if port_type == PCSConverter.get_convention(file):
103 return file
104 return None
106 def read_pcs_file(self: Solver) -> bool:
107 """Checks if the pcs file can be read."""
108 # TODO: Should be a .validate method instead
109 return PCSConverter.get_convention(self.pcs_file) is not None
111 def get_cs(self: Solver) -> ConfigurationSpace:
112 """Get the ConfigurationSpace of the PCS file."""
113 if not self.pcs_file:
114 return None
115 return PCSConverter.parse(self.pcs_file)
117 def port_pcs(self: Solver, port_type: PCSConvention) -> None:
118 """Port the parameter file to the given port type."""
119 target_pcs_file =\
120 self.pcs_file.parent / f"{self.pcs_file.stem}_{port_type.name}.pcs"
121 if target_pcs_file.exists(): # Already exists, possibly user defined
122 return
123 PCSConverter.export(self.get_cs(), port_type, target_pcs_file)
125 def build_cmd(self: Solver,
126 instance: str | list[str],
127 objectives: list[SparkleObjective],
128 seed: int,
129 cutoff_time: int = None,
130 configuration: dict = None,
131 log_dir: Path = None) -> list[str]:
132 """Build the solver call on an instance with a configuration.
134 Args:
135 instance: Path to the instance.
136 seed: Seed of the solver.
137 cutoff_time: Cutoff time for the solver.
138 configuration: Configuration of the solver.
140 Returns:
141 List of commands and arguments to execute the solver.
142 """
143 if configuration is None:
144 configuration = {}
145 # Ensure configuration contains required entries for each wrapper
146 configuration["solver_dir"] = str(self.directory.absolute())
147 configuration["instance"] = instance
148 configuration["seed"] = seed
149 configuration["objectives"] = ",".join([str(obj) for obj in objectives])
150 configuration["cutoff_time"] =\
151 cutoff_time if cutoff_time is not None else sys.maxsize
152 if "configuration_id" in configuration:
153 del configuration["configuration_id"]
154 # Ensure stringification of dictionary will go correctly for key value pairs
155 configuration = {key: str(configuration[key]) for key in configuration}
156 solver_cmd = [str((self.directory / Solver.wrapper)),
157 f"'{json.dumps(configuration)}'"]
158 if log_dir is None:
159 log_dir = Path()
160 if cutoff_time is not None: # Use RunSolver
161 log_name_base = f"{Path(instance).name}_{self.name}"
162 return RunSolver.wrap_command(self.runsolver_exec,
163 solver_cmd,
164 cutoff_time,
165 log_dir,
166 log_name_base=log_name_base)
167 return solver_cmd
169 def run(self: Solver,
170 instances: str | list[str] | InstanceSet | list[InstanceSet],
171 objectives: list[SparkleObjective],
172 seed: int,
173 cutoff_time: int = None,
174 configuration: dict = None,
175 run_on: Runner = Runner.LOCAL,
176 sbatch_options: list[str] = None,
177 slurm_prepend: str | list[str] | Path = None,
178 log_dir: Path = None,
179 ) -> SlurmRun | list[dict[str, Any]] | dict[str, Any]:
180 """Run the solver on an instance with a certain configuration.
182 Args:
183 instance: The instance(s) to run the solver on, list in case of multi-file.
184 In case of an instance set, will run on all instances in the set.
185 seed: Seed to run the solver with. Fill with abitrary int in case of
186 determnistic solver.
187 cutoff_time: The cutoff time for the solver, measured through RunSolver.
188 If None, will be executed without RunSolver.
189 configuration: The solver configuration to use. Can be empty.
190 run_on: Whether to run on slurm or locally.
191 sbatch_options: The sbatch options to use.
192 slurm_prepend: The script to prepend to a slurm script.
193 log_dir: The log directory to use.
195 Returns:
196 Solver output dict possibly with runsolver values.
197 """
198 if log_dir is None:
199 log_dir = self.raw_output_directory
200 cmds = []
201 instances = [instances] if not isinstance(instances, list) else instances
202 set_label = instances.name if isinstance(instances, InstanceSet) else "instances"
203 for instance in instances:
204 paths = instance.instace_paths if isinstance(instance,
205 InstanceSet) else [instance]
206 for instance_path in paths:
207 solver_cmd = self.build_cmd(instance_path,
208 objectives=objectives,
209 seed=seed,
210 cutoff_time=cutoff_time,
211 configuration=configuration,
212 log_dir=log_dir)
213 cmds.append(" ".join(solver_cmd))
215 commandname = f"Run Solver: {self.name} on {set_label}"
216 run = rrr.add_to_queue(runner=run_on,
217 cmd=cmds,
218 name=commandname,
219 base_dir=log_dir,
220 sbatch_options=sbatch_options,
221 prepend=slurm_prepend)
223 if isinstance(run, LocalRun):
224 run.wait()
225 import time
226 time.sleep(5)
227 # Subprocess resulted in error
228 if run.status == Status.ERROR:
229 print(f"WARNING: Solver {self.name} execution seems to have failed!\n")
230 for i, job in enumerate(run.jobs):
231 print(f"[Job {i}] The used command was: {cmds[i]}\n"
232 "The error yielded was:\n"
233 f"\t-stdout: '{run.jobs[0]._process.stdout}'\n"
234 f"\t-stderr: '{run.jobs[0]._process.stderr}'\n")
235 return {"status": SolverStatus.ERROR, }
237 solver_outputs = []
238 for i, job in enumerate(run.jobs):
239 solver_cmd = cmds[i].split(" ")
240 solver_output = Solver.parse_solver_output(run.jobs[i].stdout,
241 solver_call=solver_cmd,
242 objectives=objectives,
243 verifier=self.verifier)
244 solver_outputs.append(solver_output)
245 return solver_outputs if len(solver_outputs) > 1 else solver_output
246 return run
    def run_performance_dataframe(
            self: Solver,
            instances: str | list[str] | InstanceSet,
            run_ids: int | list[int] | range[int, int]
            | list[list[int]] | list[range[int]],
            performance_dataframe: PerformanceDataFrame,
            cutoff_time: int = None,
            objective: SparkleObjective = None,
            train_set: InstanceSet = None,
            sbatch_options: list[str] = None,
            slurm_prepend: str | list[str] | Path = None,
            dependencies: list[SlurmRun] = None,
            log_dir: Path = None,
            base_dir: Path = None,
            job_name: str = None,
            run_on: Runner = Runner.SLURM) -> Run:
        """Run the solver from and place the results in the performance dataframe.

        This in practice actually runs Solver.run, but has a little script before/after,
        to read and write to the performance dataframe.

        Args:
            instances: The instance(s) to run the solver on. In case of an
                instance set, or list, will create a job for all instances in
                the set/list.
            run_ids: The run indices to use in the performance dataframe.
                If int, will run only this id for all instances. If a list of integers
                or range, will run all run indexes for all instances.
                If a list of lists or list of ranges, will assume the runs are paired
                with the instances, e.g. will use sequence 1 for instance 1, ...
            performance_dataframe: The performance dataframe to use.
            cutoff_time: The cutoff time for the solver, measured through RunSolver.
            objective: The objective to use, only relevant for train set best config
                determining
            train_set: The training set to use. If present, will determine the best
                configuration of the solver using these instances and run with it on
                all instances in the instance argument.
            sbatch_options: List of slurm batch options to use
            slurm_prepend: Slurm script to prepend to the sbatch
            dependencies: List of slurm runs to use as dependencies
            log_dir: Path where to place output files. Defaults to
                self.raw_output_directory.
            base_dir: Path where to place output files.
            job_name: Name of the job
                If None, will generate a name based on Solver and Instances
            run_on: On which platform to run the jobs. Default: Slurm.

        Returns:
            SlurmRun or Local run of the job.
        """
        instances = [instances] if isinstance(instances, str) else instances
        set_name = "instances"
        if isinstance(instances, InstanceSet):
            set_name = instances.name
            instances = [str(i) for i in instances.instance_paths]
        # Resolve run_ids to which run indices to use for which instance
        if isinstance(run_ids, int):
            # Single id: reuse it for every instance
            run_ids = [[run_ids]] * len(instances)
        elif isinstance(run_ids, range):
            # Single range: materialise it once for every instance
            run_ids = [list(run_ids)] * len(instances)
        elif isinstance(run_ids, list):
            if all(isinstance(i, int) for i in run_ids):
                run_ids = [run_ids] * len(instances)
            elif all(isinstance(i, range) for i in run_ids):
                run_ids = [list(i) for i in run_ids]
            elif all(isinstance(i, list) for i in run_ids):
                pass
            else:
                raise TypeError(f"Invalid type combination for run_ids: {type(run_ids)}")
        # NOTE(review): a run_ids value that is none of int/range/list falls
        # through here without validation — confirm whether that is intended
        objective_arg = f"--target-objective {objective.name}" if objective else ""
        train_arg =\
            ",".join([str(i) for i in train_set.instance_paths]) if train_set else ""
        # One solver_cli invocation per (instance, run index) pair
        cmds = [
            f"python3 {Solver.solver_cli} "
            f"--solver {self.directory} "
            f"--instance {instance} "
            f"--run-index {run_index} "
            f"--performance-dataframe {performance_dataframe.csv_filepath} "
            f"--cutoff-time {cutoff_time} "
            f"--log-dir {log_dir} "
            f"{objective_arg} "
            f"{'--best-configuration-instances' if train_set else ''} {train_arg}"
            for instance, run_indices in zip(instances, run_ids)
            for run_index in run_indices]
        job_name = f"Run: {self.name} on {set_name}" if job_name is None else job_name
        # NOTE(review): Solver.run passes this as `prepend=`; here it is
        # `slurm_prepend=` — verify which keyword runrunner's add_to_queue expects
        r = rrr.add_to_queue(
            runner=run_on,
            cmd=cmds,
            name=job_name,
            base_dir=base_dir,
            sbatch_options=sbatch_options,
            slurm_prepend=slurm_prepend,
            dependencies=dependencies
        )
        if run_on == Runner.LOCAL:
            r.wait()
        return r
345 @staticmethod
346 def config_str_to_dict(config_str: str) -> dict[str, str]:
347 """Parse a configuration string to a dictionary."""
348 # First we filter the configuration of unwanted characters
349 config_str = config_str.strip().replace("-", "")
350 # Then we split the string by spaces, but conserve substrings
351 config_list = shlex.split(config_str)
352 # We return empty for empty input OR uneven input
353 if config_str == "" or config_str == r"{}" or len(config_list) & 1:
354 return {}
355 config_dict = {}
356 for index in range(0, len(config_list), 2):
357 # As the value will already be a string object, no quotes are allowed in it
358 value = config_list[index + 1].strip('"').strip("'")
359 config_dict[config_list[index]] = value
360 return config_dict
    @staticmethod
    def parse_solver_output(
            solver_output: str,
            solver_call: list[str | Path] = None,
            objectives: list[SparkleObjective] = None,
            verifier: verifiers.SolutionVerifier = None) -> dict[str, Any]:
        """Parse the output of the solver.

        Args:
            solver_output: The output of the solver run which needs to be parsed
            solver_call: The solver call used to run the solver
            objectives: The objectives to apply to the solver output
            verifier: The verifier to check the solver output

        Returns:
            Dictionary representing the parsed solver output
        """
        used_runsolver = False
        # More than [wrapper, config] arguments implies a RunSolver-wrapped call
        if solver_call is not None and len(solver_call) > 2:
            used_runsolver = True
            parsed_output = RunSolver.get_solver_output(solver_call,
                                                        solver_output)
        else:
            # Without RunSolver the wrapper output is a Python literal dict
            parsed_output = ast.literal_eval(solver_output)
        # cast status attribute from str to Enum
        parsed_output["status"] = SolverStatus(parsed_output["status"])
        # Apply objectives to parsed output, runtime based objectives added here
        if verifier is not None and used_runsolver:
            # Horrible hack to get the instance from the solver input:
            # re-extract the JSON dict argument that follows the wrapper name
            solver_call_str: str = " ".join(solver_call)
            solver_input_str = solver_call_str.split(Solver.wrapper, maxsplit=1)[1]
            solver_input_str = solver_input_str[solver_input_str.index("{"):
                                                solver_input_str.index("}") + 1]
            solver_input = ast.literal_eval(solver_input_str)
            target_instance = Path(solver_input["instance"])
            # The verifier may override the reported status
            parsed_output["status"] = verifier.verify(
                target_instance, parsed_output, solver_call)

        # Create objective map keyed by objective stem
        objectives = {o.stem: o for o in objectives} if objectives else {}
        removable_keys = ["cutoff_time"]  # Keys to remove

        # apply objectives to parsed output, runtime based objectives added here
        # (only existing keys are rebound, so iterating .items() while
        # assigning is safe here)
        for key, value in parsed_output.items():
            if objectives and key in objectives:
                objective = objectives[key]
                removable_keys.append(key)  # We translate it into the full name
            else:
                objective = resolve_objective(key)
            # If not found in objectives, resolve to which objective the output belongs
            if objective is None:  # Could not parse, skip
                continue
            if objective.use_time == UseTime.NO:
                if objective.post_process is not None:
                    parsed_output[key] = objective.post_process(value)
            else:
                # Time-based objectives need RunSolver measurements
                if not used_runsolver:
                    continue
                if objective.use_time == UseTime.CPU_TIME:
                    parsed_output[key] = parsed_output["cpu_time"]
                else:
                    parsed_output[key] = parsed_output["wall_time"]
                if objective.post_process is not None:
                    parsed_output[key] = objective.post_process(
                        parsed_output[key],
                        parsed_output["cutoff_time"],
                        parsed_output["status"])

        # Replace or remove keys based on the objective names
        for key in removable_keys:
            if key in parsed_output:
                if key in objectives:
                    # Map the result to the objective
                    parsed_output[objectives[key].name] = parsed_output[key]
                    if key != objectives[key].name:  # Only delete actual mappings
                        del parsed_output[key]
                else:
                    del parsed_output[key]
        return parsed_output