Coverage for src / sparkle / selector / extractor.py: 81%

133 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 15:31 +0000

1"""Methods regarding feature extractors.""" 

2 

3from __future__ import annotations 

4from typing import Any 

5from pathlib import Path 

6import ast 

7import re 

8import subprocess 

9 

10import runrunner as rrr 

11from runrunner.base import Status, Runner 

12from runrunner.local import Run, LocalRun 

13 

14from sparkle.types import SparkleCallable, SolverStatus 

15from sparkle.structures import FeatureDataFrame 

16from sparkle.tools import RunSolver 

17from sparkle.instance import InstanceSet 

18 

19 

class Extractor(SparkleCallable):
    """Extractor base class for extracting features from instances."""

    # Base file name of the wrapper script; resolved to a .sh or .py file
    # inside the extractor directory by the `wrapper` property.
    wrapper_file_name = "sparkle_extractor_wrapper"
    # Companion CLI script (lives next to this module) used by run_cli to
    # compute features per instance and write them to the feature CSV.
    extractor_cli = Path(__file__).parent / "extractor_cli.py"

    # Parses a RunRunner job's stdout line: two '/'-separated float
    # timestamps, optional whitespace, then a bracketed Python-literal list
    # of feature values (captured as group "output"); tolerates a trailing
    # carriage return and any trailing text.
    output_pattern = re.compile(
        r"(?P<timestamp1>\d+\.\d+)\/(?P<timestamp2>\d+\.\d+)\s*(?P<output>\[[^\]]*\])(?:\r)?.*"
    )

29 

    def __init__(self: Extractor, directory: Path) -> None:
        """Initialize the extractor.

        Args:
            directory: Directory of the extractor (containing the wrapper).
        """
        super().__init__(directory)
        # Lazily computed caches, populated by the corresponding properties.
        self._features: list[tuple[str, str]] | None = None
        self._feature_groups: list[str] | None = None
        self._groupwise_computation: bool | None = None
        self._wrapper: Path | None = None

43 

44 def __str__(self: Extractor) -> str: 

45 """Return the string representation of the extractor.""" 

46 return self.name 

47 

48 def __repr__(self: Extractor) -> str: 

49 """Return detailed representation of the extractor.""" 

50 return ( 

51 f"{self.name}:\n" 

52 f"\t- Directory: {self.directory}\n" 

53 f"\t- Wrapper: {self.wrapper}\n" 

54 f"\t- # Feature Groups: {len(self.feature_groups)}\n" 

55 f"\t- Output Dimension (# Features): {self.output_dimension}\n" 

56 f"\t- Groupwise Computation Enabled: {self.groupwise_computation}" 

57 ) 

58 

59 @property 

60 def wrapper(self: Extractor) -> Path: 

61 """Determines the Path to the Extractor wrapper.""" 

62 if self._wrapper is None: 

63 if (self.directory / f"{Extractor.wrapper_file_name}.sh").exists(): 

64 self._wrapper = self.directory / f"{Extractor.wrapper_file_name}.sh" 

65 elif (self.directory / f"{Extractor.wrapper_file_name}.py").exists(): 

66 self._wrapper = self.directory / f"{Extractor.wrapper_file_name}.py" 

67 return self._wrapper 

68 

69 @property 

70 def features(self: Extractor) -> list[tuple[str, str]]: 

71 """Determines the features of the extractor.""" 

72 if self._features is None: 

73 extractor_process = subprocess.run( 

74 [self.wrapper, "-features"], capture_output=True 

75 ) 

76 self._features = ast.literal_eval(extractor_process.stdout.decode()) 

77 return self._features 

78 

79 @property 

80 def feature_groups(self: Extractor) -> list[str]: 

81 """Returns the various feature groups the Extractor has.""" 

82 if self._feature_groups is None: 

83 self._feature_groups = list(set([group for group, _ in self.features])) 

84 return self._feature_groups 

85 

86 @property 

87 def output_dimension(self: Extractor) -> int: 

88 """The size of the output vector of the extractor.""" 

89 return len(self.features) 

90 

91 @property 

92 def groupwise_computation(self: Extractor) -> bool: 

93 """Determines if you can call the extractor per group for parallelisation.""" 

94 if self._groupwise_computation is None: 

95 extractor_help = subprocess.run([self.wrapper, "-h"], capture_output=True) 

96 # Not the cleanest / most precise way to determine this 

97 self._groupwise_computation = ( 

98 "-feature_group" in extractor_help.stdout.decode() 

99 ) 

100 return self._groupwise_computation 

101 

102 def build_cmd( 

103 self: Extractor, 

104 instance: Path | list[Path], 

105 feature_group: str = None, 

106 output_file: Path = None, 

107 cutoff_time: int = None, 

108 log_dir: Path = None, 

109 ) -> list[str]: 

110 """Builds a command line string seperated by space. 

111 

112 Args: 

113 instance: The instance to run on 

114 feature_group: The optional feature group to run the extractor for. 

115 output_file: Optional file to write the output to. 

116 runsolver_args: The arguments for runsolver. If not present, 

117 will run the extractor without runsolver. 

118 cutoff_time: The maximum runtime. 

119 log_dir: Directory path for logs. 

120 

121 Returns: 

122 The command seperated per item in the list. 

123 """ 

124 cmd_list_extractor = [] 

125 if not isinstance(instance, list): 

126 instance = [instance] 

127 cmd_list_extractor = [ 

128 f"{self.wrapper}", 

129 "-extractor_dir", 

130 f"{self.directory}/", 

131 "-instance_file", 

132 ] + [str(file) for file in instance] 

133 if feature_group is not None: 

134 cmd_list_extractor += ["-feature_group", feature_group] 

135 if output_file is not None: 

136 cmd_list_extractor += ["-output_file", str(output_file)] 

137 if cutoff_time is not None: 

138 # Extractor handles output file itself 

139 return RunSolver.wrap_command( 

140 self.runsolver_exec, 

141 cmd_list_extractor, 

142 cutoff_time, 

143 log_dir, 

144 log_name_base=self.name, 

145 raw_results_file=False, 

146 ) 

147 return cmd_list_extractor 

148 

    def run(
        self: Extractor,
        instance: Path | list[Path],
        feature_group: str = None,
        output_file: Path = None,
        cutoff_time: int = None,
        log_dir: Path = None,
    ) -> list[list[Any]] | list[Any] | None:
        """Runs an extractor job with Runrunner (locally).

        Args:
            instance: Path (or list of paths) to the instance to run on.
            feature_group: The feature group to compute. Must be supported by
                the extractor to use; ignored otherwise.
            output_file: Target output. If None, piped to the RunRunner job.
            cutoff_time: CPU cutoff time in seconds.
            log_dir: Directory to write logs. Defaults to CWD.

        Returns:
            The features, or None if an output file is used or features can
            not be found.

        Raises:
            TimeoutError: If the extractor was cut off by RunSolver.
            RuntimeError: If the extractor failed to run or produced no
                features.
        """
        log_dir = Path() if log_dir is None else log_dir
        if feature_group is not None and not self.groupwise_computation:
            # This extractor cannot handle groups, compute all features
            feature_group = None
        cmd_extractor = self.build_cmd(
            instance, feature_group, output_file, cutoff_time, log_dir
        )

        # Find runsolver values file if applicable: with a cutoff, build_cmd
        # wraps the command with RunSolver, and the argument after -v/--var
        # is the file RunSolver writes its status values to.
        runsolver_values_path = None
        if cutoff_time is not None:
            for flag in ("-v", "--var"):
                if flag in cmd_extractor:
                    flag_index = cmd_extractor.index(flag)
                    if flag_index + 1 < len(cmd_extractor):
                        runsolver_values_path = Path(cmd_extractor[flag_index + 1])
                    break

        def decode_stream(stream: Any) -> str:
            """Normalize stdout/stderr to a string, decoding bytes and handling None."""
            if stream is None:
                return ""
            if isinstance(stream, bytes):
                # use replace to substitute undecodable bytes with replacement character
                return stream.decode(errors="replace")
            return str(stream)

        run_on = Runner.LOCAL  # TODO: Let this function also handle Slurm runs
        extractor_run = rrr.add_to_queue(runner=run_on, cmd=" ".join(cmd_extractor))
        if isinstance(extractor_run, LocalRun):
            extractor_run.wait()

            # Capture each job's output once, reused below for error reports.
            job_logs = [
                (decode_stream(job.stdout), decode_stream(job.stderr))
                for job in extractor_run.jobs
            ]

            # A RunSolver timeout takes precedence over other failure modes.
            if (
                runsolver_values_path is not None
                and RunSolver.get_status(runsolver_values_path, None)
                == SolverStatus.TIMEOUT
            ):
                raise TimeoutError(
                    f"{self.name} timed out after {cutoff_time}s on {instance}."
                )

            if extractor_run.status == Status.ERROR:
                error_details = "\n".join(
                    f"Job {i} stdout:\n{stdout or '<empty>'}\nstderr:\n{stderr or '<empty>'}"
                    for i, (stdout, stderr) in enumerate(job_logs)
                )
                raise RuntimeError(
                    f"{self.name} failed to compute features for {instance}.\n"
                    f"{error_details}"
                )
            output = []
            for job, (stdout, _) in zip(extractor_run.jobs, job_logs):
                # RunRunner adds a timestamp before the statement
                match = self.output_pattern.match(stdout)
                if match:
                    output.append(ast.literal_eval(match.group("output")))
            if not output and output_file is None:
                output_details = "\n".join(
                    f"Job {i} stdout:\n{stdout or '<empty>'}\nstderr:\n{stderr or '<empty>'}"
                    for i, (stdout, stderr) in enumerate(job_logs)
                )
                raise RuntimeError(
                    f"{self.name} did not produce feature values for {instance}.\n"
                    f"{output_details}"
                )
            # Single instance: return its vector directly instead of a list.
            if len(output) == 1:
                return output[0]
            return output
        # Non-local runners are queued asynchronously; nothing to return yet.
        return None

249 

250 def run_cli( 

251 self: Extractor, 

252 instance_set: InstanceSet | list[Path], 

253 feature_dataframe: FeatureDataFrame, 

254 cutoff_time: int, 

255 feature_group: str = None, 

256 run_on: Runner = Runner.SLURM, 

257 sbatch_options: list[str] = None, 

258 srun_options: list[str] = None, 

259 parallel_jobs: int = None, 

260 slurm_prepend: str | list[str] | Path = None, 

261 dependencies: list[Run] = None, 

262 log_dir: Path = None, 

263 ) -> None: 

264 """Run the Extractor CLI and write result to the FeatureDataFrame. 

265 

266 Args: 

267 instance_set: The instance set to run the Extractor on. 

268 feature_dataframe: The feature dataframe to write to. 

269 cutoff_time: CPU cutoff time in seconds 

270 feature_group: The feature group to compute. If left empty, 

271 will run on all feature groups. 

272 run_on: The runner to use. 

273 sbatch_options: Additional options to pass to sbatch. 

274 srun_options: Additional options to pass to srun. 

275 parallel_jobs: Number of parallel jobs to run. 

276 slurm_prepend: Slurm script to prepend to the sbatch 

277 dependencies: List of dependencies to add to the job. 

278 log_dir: The directory to write logs to. 

279 """ 

280 instances = ( 

281 instance_set 

282 if isinstance(instance_set, list) 

283 else instance_set.instance_paths 

284 ) 

285 log_dir = Path() if log_dir is None else log_dir 

286 feature_group = f"--feature-group {feature_group} " if feature_group else "" 

287 commands = [ 

288 f"python3 {Extractor.extractor_cli} " 

289 f"--extractor {self.directory} " 

290 f"--instance {instance_path} " 

291 f"--feature-csv {feature_dataframe.csv_filepath} " 

292 f"{feature_group}" 

293 f"--cutoff {cutoff_time} " 

294 f"--log-dir {log_dir}" 

295 for instance_path in instances 

296 ] 

297 

298 job_name = f"Run Extractor {self.name} on {feature_group} for {len(instances)} instances" 

299 import subprocess 

300 

301 run = rrr.add_to_queue( 

302 runner=run_on, 

303 cmd=commands, 

304 name=job_name, 

305 stdout=None if run_on == Runner.LOCAL else subprocess.PIPE, # Print 

306 stderr=None if run_on == Runner.LOCAL else subprocess.PIPE, # Print 

307 base_dir=log_dir, 

308 sbatch_options=sbatch_options, 

309 srun_options=srun_options, 

310 parallel_jobs=parallel_jobs, 

311 prepend=slurm_prepend, 

312 dependencies=dependencies, 

313 ) 

314 if isinstance(run, LocalRun): 

315 print("Waiting for the local calculations to finish.") 

316 run.wait() 

317 for job in run.jobs: 

318 jobs_done = sum(j.status == Status.COMPLETED for j in run.jobs) 

319 print(f"Executing Progress: {jobs_done} out of {len(run.jobs)}") 

320 if jobs_done == len(run.jobs): 

321 break 

322 job.wait() 

323 print("Computing features done!") 

324 else: 

325 print(f"Running {self.name} through Slurm with Job IDs: {run.run_id}") 

326 return run 

327 

328 def get_feature_vector( 

329 self: Extractor, result: Path, runsolver_values: Path = None 

330 ) -> list[str]: 

331 """Extracts feature vector from an output file. 

332 

333 Args: 

334 result: The raw output of the extractor 

335 runsolver_values: The output of runsolver. 

336 

337 Returns: 

338 A list of features. Vector of missing values upon failure. 

339 """ 

340 if ( 

341 result.exists() 

342 and RunSolver.get_status(runsolver_values, None) != SolverStatus.TIMEOUT 

343 ): 

344 feature_values = ast.literal_eval(result.read_text()) 

345 return [str(value) for _, _, value in feature_values] 

346 return [FeatureDataFrame.missing_value] * self.output_dimension