Coverage for sparkle/solver/validator.py: 95%

107 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-05 14:48 +0000

1"""File containing the Validator class.""" 

2 

3from __future__ import annotations 

4 

5import sys 

6from pathlib import Path 

7import csv 

8import ast 

9import runrunner as rrr 

10from runrunner import Runner, Run 

11 

12from sparkle.solver import Solver 

13from sparkle.instance import InstanceSet 

14from sparkle.types import SparkleObjective, resolve_objective 

15from sparkle.tools import RunSolver 

16 

17 

18class Validator(): 

19 """Class to handle the validation of solvers on instance sets.""" 

20 def __init__(self: Validator, 

21 out_dir: Path = Path(), 

22 tmp_out_dir: Path = Path()) -> None: 

23 """Construct the validator.""" 

24 self.out_dir = out_dir 

25 self.tmp_out_dir = tmp_out_dir 

26 

27 def validate(self: Validator, 

28 solvers: list[Path] | list[Solver] | Solver | Path, 

29 configurations: list[dict] | dict | Path, 

30 instance_sets: list[InstanceSet], 

31 objectives: list[SparkleObjective], 

32 cut_off: int, 

33 subdir: Path = None, 

34 dependency: list[Run] | Run = None, 

35 sbatch_options: list[str] = [], 

36 run_on: Runner = Runner.SLURM) -> Run: 

37 """Validate a list of solvers (with configurations) on a set of instances. 

38 

39 Args: 

40 solvers: list of solvers to validate 

41 configurations: list of configurations for each solver we validate. 

42 If a path is supplied, will use each line as a configuration. 

43 instance_sets: set of instance sets on which we want to validate each solver 

44 objectives: list of objectives to validate 

45 cut_off: maximum run time for the solver per instance 

46 subdir: The subdir where to place the output in the outputdir. If None, 

47 a semi-unique combination of solver_instanceset is created. 

48 dependency: Jobs to wait for before executing the validation. 

49 sbatch_options: list of slurm batch options 

50 run_on: whether to run on SLURM or local 

51 """ 

52 if not isinstance(solvers, list) and isinstance(configurations, list): 

53 # If we receive one solver but multiple configurations, we cas the 

54 # Solvers to a list of the same length 

55 solvers = [solvers] * len(configurations) 

56 elif not isinstance(configurations, list) and isinstance(solvers, list): 

57 # If there is only one configuration, we cast it to a list of the same 

58 # length as the solver list 

59 configurations = [configurations] * len(solvers) 

60 if not isinstance(solvers, list) or len(configurations) != len(solvers): 

61 print("Error: Number of solvers and configurations does not match!") 

62 sys.exit(-1) 

63 # Ensure we have the object representation of solvers 

64 solvers = [Solver(s) if isinstance(s, Path) else s for s in solvers] 

65 cmds = [] 

66 for index, (solver, config) in enumerate(zip(solvers, configurations)): 

67 if config is None: 

68 config = {} 

69 elif isinstance(config, Path): 

70 # Point to the config line in file 

71 config = {"config_path": config} 

72 for instance_set in instance_sets: 

73 if subdir is None: 

74 out_path = self.out_dir / f"{solver.name}_{instance_set.name}" 

75 else: 

76 out_path = self.out_dir / subdir 

77 out_path.mkdir(exist_ok=True) 

78 for instance_path in instance_set._instance_paths: 

79 cmds.append(" ".join( 

80 solver.build_cmd(instance=instance_path.absolute(), 

81 objectives=objectives, 

82 seed=index, 

83 cutoff_time=cut_off, 

84 configuration=config, 

85 log_dir=out_path))) 

86 return rrr.add_to_queue( 

87 runner=run_on, 

88 cmd=cmds, 

89 name=f"Validation: {','.join(set([s.name for s in solvers]))} on " 

90 f"{','.join([i.name for i in instance_sets])}", 

91 base_dir=self.tmp_out_dir, 

92 dependencies=dependency, 

93 sbatch_options=sbatch_options, 

94 ) 

95 

96 def retrieve_raw_results(self: Validator, 

97 solver: Solver, 

98 instance_sets: InstanceSet | list[InstanceSet], 

99 subdir: Path = None, 

100 log_dir: Path = None) -> None: 

101 """Checks the raw results of a given solver for a specific instance_set. 

102 

103 Writes the raw results to a unified CSV file for the resolve/instance_set 

104 combination. 

105 

106 Args: 

107 solver: The solver for which to check the raw result path 

108 instance_sets: The set of instances for which to retrieve the results 

109 subdir: Subdir where the CSV is to be placed, passed to the append method. 

110 log_dir: The directory to search for log files. If none, defaults to 

111 the log directory of the Solver. 

112 """ 

113 if isinstance(instance_sets, InstanceSet): 

114 instance_sets = [instance_sets] 

115 if log_dir is None: 

116 log_dir = solver.raw_output_directory 

117 for res in log_dir.iterdir(): 

118 if res.suffix != ".rawres": 

119 continue 

120 solver_args = RunSolver.get_solver_args(res.with_suffix(".log")) 

121 solver_args = ast.literal_eval(solver_args) 

122 instance_path = Path(solver_args["instance"]) 

123 # Remove default args 

124 if "config_path" in solver_args: 

125 # The actual solver configuration can be found elsewhere 

126 row_idx = int(solver_args["seed"]) 

127 config_path = Path(solver_args["config_path"]) 

128 if not config_path.exists(): 

129 config_path = log_dir / config_path 

130 config_str = config_path.open("r").readlines()[row_idx] 

131 solver_args = Solver.config_str_to_dict(config_str) 

132 else: 

133 for def_arg in ["instance", "solver_dir", "cutoff_time", 

134 "seed", "objectives"]: 

135 if def_arg in solver_args: 

136 del solver_args[def_arg] 

137 solver_args = str(solver_args).replace('"', "'") 

138 

139 for instance_set in instance_sets: 

140 if instance_path.name in instance_set._instance_names: 

141 out_dict = Solver.parse_solver_output( 

142 "", 

143 ["-o", res, 

144 "-v", res.with_suffix(".val"), 

145 "-w", res.with_suffix(".log")]) 

146 self.append_entry_to_csv(solver.name, 

147 solver_args, 

148 instance_set, 

149 instance_path.name, 

150 solver_output=out_dict, 

151 subdir=subdir) 

152 res.unlink() 

153 res.with_suffix(".val").unlink(missing_ok=True) 

154 res.with_suffix(".log").unlink(missing_ok=True) 

155 

156 def get_validation_results(self: Validator, 

157 solver: Solver, 

158 instance_set: InstanceSet, 

159 source_dir: Path = None, 

160 subdir: Path = None, 

161 config: str = None) -> list[list[str]]: 

162 """Query the results of the validation of solver on instance_set. 

163 

164 Args: 

165 solver: Solver object 

166 instance_set: Instance set 

167 source_dir: Path where to look for any unprocessed output. 

168 By default, look in the solver's tmp dir. 

169 subdir: Path where to place the .csv file subdir. By default will be 

170 'self.outputdir/solver.name_instanceset.name/validation.csv' 

171 config: Path to the configuration if the solver was configured, None 

172 otherwise 

173 Returns 

174 A list of row lists with string values 

175 """ 

176 if source_dir is None: 

177 source_dir = self.out_dir / f"{solver.name}_{instance_set.name}" 

178 if any(x.suffix == ".rawres" for x in source_dir.iterdir()): 

179 self.retrieve_raw_results( 

180 solver, instance_set, subdir=subdir, log_dir=source_dir) 

181 if subdir is None: 

182 subdir = Path(f"{solver.name}_{instance_set.name}") 

183 csv_file = self.out_dir / subdir / "validation.csv" 

184 csv_data = [line for line in csv.reader(csv_file.open("r"))] 

185 header = csv_data[0] 

186 if config is not None: 

187 # We filter on the config string by subdict 

188 if isinstance(config, str): 

189 config = Solver.config_str_to_dict(config) 

190 csv_data = [line for line in csv_data[1:] if 

191 config.items() == ast.literal_eval(line[1]).items()] 

192 csv_data.insert(0, header) 

193 return csv_data 

194 

195 def append_entry_to_csv(self: Validator, 

196 solver: str, 

197 config_str: str, 

198 instance_set: InstanceSet, 

199 instance: str, 

200 solver_output: dict, 

201 subdir: Path = None) -> None: 

202 """Append a validation result as a row to a CSV file.""" 

203 if subdir is None: 

204 subdir = Path(f"{solver}_{instance_set.name}") 

205 out_dir = self.out_dir / subdir 

206 if not out_dir.exists(): 

207 out_dir.mkdir(parents=True) 

208 csv_file = out_dir / "validation.csv" 

209 status = solver_output["status"] 

210 cpu_time = solver_output["cpu_time"] 

211 wall_time = solver_output["wall_time"] 

212 del solver_output["status"] 

213 del solver_output["cpu_time"] 

214 del solver_output["wall_time"] 

215 sorted_keys = sorted(solver_output) 

216 objectives = [resolve_objective(key) for key in sorted_keys] 

217 objectives = [o for o in objectives if o is not None] 

218 if not csv_file.exists(): 

219 # Write header 

220 header = ["Solver", "Configuration", "InstanceSet", "Instance", "Status", 

221 "CPU Time", "Wallclock Time"] + [o.name for o in objectives] 

222 with csv_file.open("w") as out: 

223 csv.writer(out).writerow((header)) 

224 values = [solver, config_str, instance_set.name, instance, status, cpu_time, 

225 wall_time] + [solver_output[o.name] for o in objectives] 

226 with csv_file.open("a") as out: 

227 writer = csv.writer(out) 

228 writer.writerow(values)