Coverage for src/sparkle/CLI/compute_features.py: 91%

70 statements  

coverage.py v7.10.7, created at 2025-10-15 14:11 +0000

#!/usr/bin/env python3
"""Sparkle command to compute features for instances."""

from __future__ import annotations
import sys
import argparse
from pathlib import Path

from runrunner.base import Run, Runner

from sparkle.selector import Extractor
from sparkle.platform.settings_objects import Settings
from sparkle.structures import FeatureDataFrame
from sparkle.instance import Instance_Set, InstanceSet


from sparkle.CLI.help import global_variables as gv
from sparkle.CLI.help import logging as sl
from sparkle.CLI.help import argparse_custom as ac
from sparkle.CLI.initialise import check_for_initialise
from sparkle.CLI.help.nicknames import resolve_instance_name


def parser_function() -> argparse.ArgumentParser:
    """Define the command line arguments."""
    parser = argparse.ArgumentParser(
        description="Sparkle command to compute features "
                    "for instances using the added extractors."
    )
    parser.add_argument(
        *ac.RecomputeFeaturesArgument.names, **ac.RecomputeFeaturesArgument.kwargs
    )
    # Settings arguments
    parser.add_argument(*ac.SettingsFileArgument.names, **ac.SettingsFileArgument.kwargs)
    parser.add_argument(*Settings.OPTION_run_on.args, **Settings.OPTION_run_on.kwargs)
    return parser


def compute_features(
    feature_data: Path | FeatureDataFrame,
    recompute: bool,
    run_on: Runner = Runner.SLURM,
) -> list[Run] | None:
    """Compute features for all instance and feature extractor combinations.

    A RunRunner run is submitted for the computation of the features.
    The results are then stored in the CSV file underlying feature_data.

    Args:
        feature_data: Feature data frame to use, or a path to read it from.
        recompute: Specifies if features should be recomputed.
        run_on: Runner
            On which computer or cluster environment to run the extractors.
            Available: Runner.LOCAL, Runner.SLURM. Default: Runner.SLURM

    Returns:
        The list of submitted Slurm or local runs, or None if there were no
        jobs to run.
    """

    if isinstance(feature_data, Path):
        feature_data = FeatureDataFrame(feature_data)
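    # Recomputing wipes all stored feature values, so every job is pending again.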

    if recompute:
        feature_data.reset_dataframe()
    jobs = feature_data.remaining_jobs()

    # Look up all instances to resolve the instance paths later
    instances: list[InstanceSet] = []
    for instance_dir in gv.settings().DEFAULT_instance_dir.iterdir():
        if instance_dir.is_dir():
            instances.append(Instance_Set(instance_dir))

    # If there are no jobs, stop
    if not jobs:
        print(
            "No feature computation jobs to run; stopping execution! To recompute "
            "feature values use the --recompute flag."
        )
        return None
    cutoff = gv.settings().extractor_cutoff_time
    grouped_job_list: dict[str, dict[str, list[str]]] = {}

    # Group the jobs by extractor/feature group
    for instance_name, extractor_name, feature_group in jobs:
        if extractor_name not in grouped_job_list:
            grouped_job_list[extractor_name] = {}
        if feature_group not in grouped_job_list[extractor_name]:
            grouped_job_list[extractor_name][feature_group] = []
        instance_path = resolve_instance_name(str(instance_name), instances)
        grouped_job_list[extractor_name][feature_group].append(instance_path)

    sbatch_options = gv.settings().sbatch_settings
    slurm_prepend = gv.settings().slurm_job_prepend
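    # "-N1 -n1": allocate a single node and a single task per extractor run.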

    srun_options = ["-N1", "-n1"] + sbatch_options
    runs = []
    for extractor_name, feature_groups in grouped_job_list.items():
        extractor_path = gv.settings().DEFAULT_extractor_dir / extractor_name
        extractor = Extractor(extractor_path)
        for feature_group, instance_paths in feature_groups.items():
            run = extractor.run_cli(
                instance_paths,
                feature_data,
                cutoff,
                feature_group if extractor.groupwise_computation else None,
                run_on,
                sbatch_options,
                srun_options,
                gv.settings().slurm_jobs_in_parallel,
                slurm_prepend,
                log_dir=sl.caller_log_dir,
            )
            runs.append(run)
    return runs

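# A minimal usage sketch, assuming an initialised Sparkle platform; the CSV
# path below is hypothetical:
#     runs = compute_features(
#         Path("Output/Feature_Data/feature_data.csv"),
#         recompute=False,
#         run_on=Runner.LOCAL,
#     )
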

def main(argv: list[str]) -> None:
    """Main function of the compute features command."""
    # Define command line arguments
    parser = parser_function()

    # Process command line arguments
    args = parser.parse_args(argv)
    settings = gv.settings(args)
    run_on = settings.run_on

    # Log command call
    sl.log_command(sys.argv, settings.random_state)
    check_for_initialise()

    # Check if there are any feature extractors registered
    if not any(p.is_dir() for p in gv.settings().DEFAULT_extractor_dir.iterdir()):
        print(
            "No feature extractors present! Add feature extractors to Sparkle "
            "by using the add_feature_extractor command."
        )
        sys.exit()

    # Start computing features
    print("Start computing features ...")
    compute_features(settings.DEFAULT_feature_data_path, args.recompute, run_on)

    # Write used settings to file
    gv.settings().write_used_settings()
    sys.exit(0)


if __name__ == "__main__":
    main(sys.argv[1:])
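
# Example invocation through the Sparkle CLI (the --recompute flag matches the
# argument definition above; the exact command name is an assumption):
#     sparkle compute_features --recompute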