Coverage for sparkle/tools/configspace.py: 61%

241 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-03 10:42 +0000

1"""Extensions of the ConfigSpace lib.""" 

2from __future__ import annotations 

3from typing_extensions import override, Iterable 

4import ast 

5 

6import numpy as np 

7from typing import Any 

8from ConfigSpace import ConfigurationSpace 

9from ConfigSpace.hyperparameters import Hyperparameter 

10from ConfigSpace.types import Array, Mask, f64 

11 

12from ConfigSpace.conditions import ( 

13 Condition, 

14 AndConjunction, 

15 OrConjunction, 

16 EqualsCondition, 

17 GreaterThanCondition, 

18 InCondition, 

19 LessThanCondition, 

20 NotEqualsCondition 

21) 

22 

23from ConfigSpace.forbidden import ( 

24 ForbiddenGreaterThanRelation, 

25 ForbiddenLessThanRelation, 

26 ForbiddenClause, 

27 ForbiddenConjunction, 

28 ForbiddenRelation, 

29 ForbiddenInClause, 

30 ForbiddenEqualsClause, 

31 ForbiddenAndConjunction 

32) 

33 

# Module-private marker used as the default for ``dict.get`` so that a key that
# is absent from a values mapping can be told apart from any stored value
# (callers compare against it with ``is``).
_SENTINEL = object()

35 

36 

def expression_to_configspace(
        expression: str | ast.Module,
        configspace: ConfigurationSpace,
        target_parameter: Hyperparameter | None = None) -> ForbiddenClause | Condition:
    """Convert a logic expression to ConfigSpace expression.

    Args:
        expression: The expression to convert, either as source text or as an
            already parsed module. It must contain exactly one statement.
        configspace: The ConfigSpace used to resolve hyperparameter names.
        target_parameter: For conditions, will parse the expression as a condition
            under which the parameter will be active. When None, the expression
            is parsed as a forbidden clause.

    Returns:
        The ConfigSpace forbidden clause or condition described by the expression.

    Raises:
        ValueError: If the expression cannot be parsed or does not consist of
            exactly one statement.
    """
    source = expression
    if isinstance(expression, str):
        try:
            expression = ast.parse(expression)
        except Exception as e:
            raise ValueError(f"Could not parse expression: '{source}', {e}") from e
    if isinstance(expression, ast.Module):
        # An empty module has nothing to convert, and silently ignoring extra
        # statements would hide malformed input, so require exactly one.
        n_statements = len(expression.body)
        if n_statements != 1:
            shown = source if isinstance(source, str) else ast.dump(expression)
            raise ValueError("Expected exactly one statement in expression "
                             f"'{shown}', found {n_statements}.")
        expression = expression.body[0]
    return recursive_conversion(expression, configspace,
                                target_parameter=target_parameter)

58 

59 

def recursive_conversion(
        item: ast.AST | list[ast.AST],
        configspace: ConfigurationSpace,
        target_parameter: Hyperparameter | None = None) -> ForbiddenClause | Condition:
    """Recursively parse the AST tree to a ConfigSpace expression.

    Args:
        item: The AST node to parse, or a list holding exactly one node.
        configspace: The ConfigSpace used to resolve names to hyperparameters.
        target_parameter: For conditions, will parse the expression as a condition
            under which the parameter will be active. When None, the expression
            is parsed as a forbidden clause.

    Returns:
        A ConfigSpace expression. Leaf nodes resolve to the matching
        hyperparameter, a constant, a bare name string, or a list of values.

    Raises:
        ValueError: For malformed or unsupported expressions.
        NotImplementedError: For binary operations and `is` / `is not`.
    """
    if isinstance(item, list):
        # Comparators arrive as a list; only a single element is meaningful.
        if len(item) != 1:
            raise ValueError(f"Can not parse list of elements: {item}.")
        item = item[0]
    if isinstance(item, ast.Expr):
        return recursive_conversion(item.value, configspace, target_parameter)
    if isinstance(item, ast.Name):  # Convert to hyperparameter
        hp = configspace.get(item.id)
        return hp if hp is not None else item.id
    if isinstance(item, ast.Constant):
        return item.value
    if _is_signed_literal(item):
        return _signed_literal_value(item)
    if isinstance(item, (ast.Tuple, ast.Set, ast.List)):
        return _sequence_to_values(item, configspace)
    if isinstance(item, ast.BinOp):
        raise NotImplementedError("Binary operations not supported by ConfigSpace.")
    if isinstance(item, ast.BoolOp):
        values = [recursive_conversion(v, configspace, target_parameter)
                  for v in item.values]
        is_condition = target_parameter is not None
        if isinstance(item.op, ast.Or):
            if is_condition:
                return OrConjunction(*values)
            return ForbiddenOrConjunction(*values)
        if isinstance(item.op, ast.And):
            if is_condition:
                return AndConjunction(*values)
            return ForbiddenAndConjunction(*values)
        raise ValueError(f"Unknown boolean operator: {item.op}")
    if isinstance(item, ast.Compare):
        return _compare_to_configspace(item, configspace, target_parameter)
    raise ValueError(f"Unsupported type: {item}")


def _is_signed_literal(node: ast.AST) -> bool:
    """Return whether node is a unary +/- applied to something (e.g. ``-1``)."""
    return isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.USub, ast.UAdd))


def _signed_literal_value(node: ast.UnaryOp) -> int | float:
    """Evaluate a signed numeric literal such as ``-1`` or ``+2.5``."""
    operand = node.operand
    if (not isinstance(operand, ast.Constant)
            or isinstance(operand.value, bool)
            or not isinstance(operand.value, (int, float))):
        raise ValueError(f"Unsupported unary operand: {ast.dump(node)}")
    return -operand.value if isinstance(node.op, ast.USub) else +operand.value


def _sequence_to_values(item: ast.Tuple | ast.Set | ast.List,
                        configspace: ConfigurationSpace) -> list[Any]:
    """Convert a tuple/set/list literal into a list of plain values."""
    values = []
    for element in item.elts:
        if isinstance(element, ast.Constant):
            values.append(element.value)
        elif _is_signed_literal(element):
            values.append(_signed_literal_value(element))
        elif isinstance(element, ast.Name):
            # Names of hyperparameters are not allowed inside a value collection;
            # check against the space's keys (parameter names), not its values.
            if element.id in configspace:
                raise ValueError("Only constants allowed in tuples. "
                                 f"Found: {item.elts}")
            values.append(element.id)  # String value was interpreted as parameter
        else:
            # Previously unknown elements were silently dropped.
            raise ValueError("Only constants allowed in tuples. "
                             f"Found: {item.elts}")
    return values


def _compare_to_configspace(
        item: ast.Compare,
        configspace: ConfigurationSpace,
        target_parameter: Hyperparameter | None) -> ForbiddenClause | Condition:
    """Convert a single comparison into a ConfigSpace condition or forbidden."""
    if len(item.ops) > 1:
        raise ValueError(f"Only single comparisons allowed. Found: {item.ops}")
    left = recursive_conversion(item.left, configspace, target_parameter)
    right = recursive_conversion(item.comparators, configspace, target_parameter)
    operator = item.ops[0]
    if isinstance(left, Hyperparameter):  # Convert to HP type
        if isinstance(right, Iterable) and not isinstance(right, str):
            right = [type(left.default_value)(v) for v in right]
            if len(right) == 1 and not isinstance(operator, ast.In):
                right = right[0]
        elif isinstance(right, int):
            right = type(left.default_value)(right)

    is_condition = target_parameter is not None
    # Comparing two hyperparameters needs a relation; comparing a
    # hyperparameter against a value needs a clause.
    is_relation = isinstance(right, Hyperparameter)

    if isinstance(operator, ast.Lt):
        if is_condition:
            return LessThanCondition(target_parameter, left, right)
        if is_relation:
            return ForbiddenLessThanRelation(left=left, right=right)
        return ForbiddenLessThanClause(hyperparameter=left, value=right)
    if isinstance(operator, ast.LtE):
        if is_condition:
            raise ValueError("LessThanEquals not supported for conditions.")
        if is_relation:
            return ForbiddenLessThanEqualsRelation(left=left, right=right)
        return ForbiddenLessEqualsClause(hyperparameter=left, value=right)
    if isinstance(operator, ast.Gt):
        if is_condition:
            return GreaterThanCondition(target_parameter, left, right)
        if is_relation:
            return ForbiddenGreaterThanRelation(left=left, right=right)
        return ForbiddenGreaterThanClause(hyperparameter=left, value=right)
    if isinstance(operator, ast.GtE):
        if is_condition:
            raise ValueError("GreaterThanEquals not supported for conditions.")
        if is_relation:
            return ForbiddenGreaterThanEqualsRelation(left=left, right=right)
        return ForbiddenGreaterEqualsClause(hyperparameter=left, value=right)
    if isinstance(operator, ast.Eq):
        if is_condition:
            return EqualsCondition(target_parameter, left, right)
        return ForbiddenEqualsClause(hyperparameter=left, value=right)
    if isinstance(operator, ast.In):
        if is_condition:
            return InCondition(target_parameter, left, right)
        return ForbiddenInClause(hyperparameter=left, values=right)
    if isinstance(operator, ast.NotEq):
        if is_condition:
            return NotEqualsCondition(target_parameter, left, right)
        raise ValueError("NotEq operator not supported for ForbiddenClauses.")
    # The following classes do not (yet?) exist in configspace
    if isinstance(operator, ast.NotIn):
        raise ValueError("NotIn operator not supported for ForbiddenClauses.")
    if isinstance(operator, ast.Is):
        raise NotImplementedError("Is operator not supported.")
    if isinstance(operator, ast.IsNot):
        raise NotImplementedError("IsNot operator not supported.")
    raise ValueError(f"Unsupported type: {item}")

163 

164 

class ForbiddenLessThanEqualsRelation(ForbiddenLessThanRelation):
    """A ForbiddenLessThanEquals relation between two hyperparameters.

    A configuration is forbidden when the value of ``left`` is less than or
    equal to the value of ``right``. When either side has no value the
    relation does not apply.
    """

    _RELATION_STR = "LESSEQUAL"

    def __repr__(self: ForbiddenLessThanEqualsRelation) -> str:
        """Return a string representation of the ForbiddenLessThanEqualsRelation."""
        lhs_name = self.left.name
        rhs_name = self.right.name
        return f"Forbidden: {lhs_name} <= {rhs_name}"

    @override
    def is_forbidden_value(self: ForbiddenLessThanEqualsRelation,
                           values: dict[str, Any]) -> bool:
        """Check if the value is forbidden."""
        lhs = values.get(self.left.name, _SENTINEL)
        rhs = values.get(self.right.name, _SENTINEL)
        # A side missing from the mapping means the relation cannot hold.
        if lhs is _SENTINEL or rhs is _SENTINEL:
            return False
        return lhs <= rhs  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenLessThanEqualsRelation,
                            vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        lhs_vec: f64 = vector[self.vector_ids[0]]  # type: ignore
        rhs_vec: f64 = vector[self.vector_ids[1]]  # type: ignore
        # Compare decoded values, not the vector representation.
        either_nan = np.isnan(lhs_vec) or np.isnan(rhs_vec)
        if not either_nan:
            return self.left.to_value(lhs_vec) <= self.right.to_value(rhs_vec)  # type: ignore
        return False

    @override
    def is_forbidden_vector_array(self: ForbiddenLessThanEqualsRelation,
                                  arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden."""
        lhs_col = arr[self.vector_ids[0]]
        rhs_col = arr[self.vector_ids[1]]
        active = np.logical_not(np.isnan(lhs_col) | np.isnan(rhs_col))
        mask = np.zeros_like(active)
        # Only entries where both sides are present get compared (as values).
        mask[active] = (self.left.to_value(lhs_col[active])
                        <= self.right.to_value(rhs_col[active]))
        return mask

210 

211 

class ForbiddenGreaterThanEqualsRelation(ForbiddenGreaterThanRelation):
    """A ForbiddenGreaterThanEquals relation between two hyperparameters.

    A configuration is forbidden when the value of ``left`` is greater than or
    equal to the value of ``right``. When either side has no value the
    relation does not apply.
    """

    _RELATION_STR = "GREATEREQUAL"

    def __repr__(self: ForbiddenGreaterThanEqualsRelation) -> str:
        """Return a string representation of the ForbiddenGreaterThanEqualsRelation."""
        lhs_name = self.left.name
        rhs_name = self.right.name
        return f"Forbidden: {lhs_name} >= {rhs_name}"

    @override
    def is_forbidden_value(self: ForbiddenGreaterThanEqualsRelation,
                           values: dict[str, Any]) -> bool:
        """Check if the value is forbidden."""
        lhs = values.get(self.left.name, _SENTINEL)
        rhs = values.get(self.right.name, _SENTINEL)
        # A side missing from the mapping means the relation cannot hold.
        if lhs is _SENTINEL or rhs is _SENTINEL:
            return False
        return lhs >= rhs  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenGreaterThanEqualsRelation,
                            vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        lhs_vec: f64 = vector[self.vector_ids[0]]  # type: ignore
        rhs_vec: f64 = vector[self.vector_ids[1]]  # type: ignore
        # Compare decoded values, not the vector representation.
        either_nan = np.isnan(lhs_vec) or np.isnan(rhs_vec)
        if not either_nan:
            return self.left.to_value(lhs_vec) >= self.right.to_value(rhs_vec)  # type: ignore
        return False

    @override
    def is_forbidden_vector_array(self: ForbiddenGreaterThanEqualsRelation,
                                  arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden."""
        lhs_col = arr[self.vector_ids[0]]
        rhs_col = arr[self.vector_ids[1]]
        active = np.logical_not(np.isnan(lhs_col) | np.isnan(rhs_col))
        mask = np.zeros_like(active)
        # Only entries where both sides are present get compared (as values).
        mask[active] = (self.left.to_value(lhs_col[active])
                        >= self.right.to_value(rhs_col[active]))
        return mask

256 

257 

class ForbiddenGreaterThanClause(ForbiddenEqualsClause):
    """A ForbiddenGreaterThanClause.

    It forbids a value from the value range of a hyperparameter to be
    *greater than* `value`.

    For example, ``ForbiddenGreaterThanClause(cs["a"], 2)`` forbids the value
    of the hyperparameter *a* to be greater than 2.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: forbidden value
    """

    def __repr__(self: ForbiddenGreaterThanClause) -> str:
        """Return a string representation of the ForbiddenGreaterThanClause."""
        return f"Forbidden: {self.hyperparameter.name} > {self.value!r}"

    @override
    def is_forbidden_value(self: ForbiddenGreaterThanClause,
                           values: dict[str, Any]) -> bool:
        """Check if the value is forbidden.

        A hyperparameter absent from ``values`` is never forbidden; comparing
        the sentinel itself would raise a TypeError.
        """
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            return False
        return value > self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenGreaterThanClause,
                            vector: Array[f64]) -> bool:
        """Check if the vector is forbidden.

        NaN entries compare as False and are therefore never forbidden.
        """
        # NOTE(review): assumes the vector encoding preserves value ordering;
        # confirm this holds for every hyperparameter type used with this clause.
        return vector[self.vector_id] > self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(self: ForbiddenGreaterThanClause,
                                  arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden."""
        return np.greater(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenGreaterThanClause) -> dict[str, Any]:
        """Convert the ForbiddenGreaterThanClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "GREATER",
            "value": self.value,
        }

303 

304 

class ForbiddenGreaterEqualsClause(ForbiddenEqualsClause):
    """A ForbiddenGreaterEqualsClause.

    It forbids a value from the value range of a hyperparameter to be
    *greater or equal to* `value`.

    For example, ``ForbiddenGreaterEqualsClause(cs["a"], 2)`` forbids the value
    of the hyperparameter *a* to be greater or equal to 2.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: forbidden value
    """

    def __repr__(self: ForbiddenGreaterEqualsClause) -> str:
        """Return a string representation of the ForbiddenGreaterEqualsClause."""
        return f"Forbidden: {self.hyperparameter.name} >= {self.value!r}"

    @override
    def is_forbidden_value(self: ForbiddenGreaterEqualsClause,
                           values: dict[str, Any]) -> bool:
        """Check if the value is forbidden.

        A hyperparameter absent from ``values`` is never forbidden; comparing
        the sentinel itself would raise a TypeError.
        """
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            return False
        return value >= self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenGreaterEqualsClause,
                            vector: Array[f64]) -> bool:
        """Check if the vector is forbidden.

        NaN entries compare as False and are therefore never forbidden.
        """
        # NOTE(review): assumes the vector encoding preserves value ordering;
        # confirm this holds for every hyperparameter type used with this clause.
        return vector[self.vector_id] >= self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(self: ForbiddenGreaterEqualsClause,
                                  arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden."""
        return np.greater_equal(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenGreaterEqualsClause) -> dict[str, Any]:
        """Convert the ForbiddenGreaterEqualsClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "GREATEREQUAL",
            "value": self.value,
        }

350 

351 

class ForbiddenLessThanClause(ForbiddenEqualsClause):
    """A ForbiddenLessThanClause.

    It forbids a value from the value range of a hyperparameter to be
    *less than* `value`.

    For example, ``ForbiddenLessThanClause(cs["a"], 2)`` forbids the value
    of the hyperparameter *a* to be less than 2.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: forbidden value
    """

    def __repr__(self: ForbiddenLessThanClause) -> str:
        """Return a string representation of the ForbiddenLessThanClause."""
        return f"Forbidden: {self.hyperparameter.name} < {self.value!r}"

    @override
    def is_forbidden_value(self: ForbiddenLessThanClause,
                           values: dict[str, Any]) -> bool:
        """Check if the value is forbidden.

        A hyperparameter absent from ``values`` is never forbidden; comparing
        the sentinel itself would raise a TypeError.
        """
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            return False
        return value < self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenLessThanClause, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden.

        NaN entries compare as False and are therefore never forbidden.
        """
        # NOTE(review): assumes the vector encoding preserves value ordering;
        # confirm this holds for every hyperparameter type used with this clause.
        return vector[self.vector_id] < self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(self: ForbiddenLessThanClause,
                                  arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden."""
        return np.less(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenLessThanClause) -> dict[str, Any]:
        """Convert the ForbiddenLessThanClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "LESS",
            "value": self.value,
        }

394 

395 

class ForbiddenLessEqualsClause(ForbiddenEqualsClause):
    """A ForbiddenLessEqualsClause.

    It forbids a value from the value range of a hyperparameter to be
    *less or equal to* `value`.

    For example, ``ForbiddenLessEqualsClause(cs["a"], 2)`` forbids the value
    of the hyperparameter *a* to be less or equal to 2.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: forbidden value
    """

    def __repr__(self: ForbiddenLessEqualsClause) -> str:
        """Return a string representation of the ForbiddenLessEqualsClause."""
        return f"Forbidden: {self.hyperparameter.name} <= {self.value!r}"

    @override
    def is_forbidden_value(self: ForbiddenLessEqualsClause,
                           values: dict[str, Any]) -> bool:
        """Check if the value is forbidden.

        A hyperparameter absent from ``values`` is never forbidden; comparing
        the sentinel itself would raise a TypeError.
        """
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            return False
        return value <= self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenLessEqualsClause, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden.

        NaN entries compare as False and are therefore never forbidden.
        """
        # NOTE(review): assumes the vector encoding preserves value ordering;
        # confirm this holds for every hyperparameter type used with this clause.
        return vector[self.vector_id] <= self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(self: ForbiddenLessEqualsClause,
                                  arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden."""
        # Must mirror is_forbidden_vector: forbidden when value <= threshold.
        return np.less_equal(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenLessEqualsClause) -> dict[str, Any]:
        """Convert the ForbiddenLessEqualsClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "LESSEQUAL",
            "value": self.value,
        }

438 

439 

class ForbiddenOrConjunction(ForbiddenConjunction):
    """A ForbiddenOrConjunction.

    The ForbiddenOrConjunction combines forbidden-clauses, which allows to
    build powerful constraints. A configuration is forbidden as soon as any
    one of the combined clauses forbids it.

    ```python exec="true", source="material-block" result="python"
    from ConfigSpace import (
        ConfigurationSpace,
        ForbiddenEqualsClause,
        ForbiddenInClause,
    )
    from sparkle.tools.configspace import ForbiddenOrConjunction

    cs = ConfigurationSpace({"a": [1, 2, 3], "b": [2, 5, 6]})
    forbidden_clause_a = ForbiddenEqualsClause(cs["a"], 2)
    forbidden_clause_b = ForbiddenInClause(cs["b"], [2])

    forbidden_clause = ForbiddenOrConjunction(forbidden_clause_a, forbidden_clause_b)

    cs.add(forbidden_clause)
    print(cs)
    ```

    Args:
        *args: forbidden clauses, which should be combined
    """

    components: tuple[ForbiddenClause | ForbiddenConjunction | ForbiddenRelation, ...]
    """Components of the conjunction."""

    dlcs: tuple[ForbiddenClause | ForbiddenRelation, ...]
    """Descendant literal clauses of the conjunction.

    These are the base forbidden clauses/relations that are part of conjunctions.

    !!! note

        This will only store a unique set of the descendant clauses, no duplicates.
    """

    def __repr__(self: ForbiddenOrConjunction) -> str:
        """Return a string representation of the ForbiddenOrConjunction."""
        return "(" + " || ".join([str(c) for c in self.components]) + ")"

    @override
    def is_forbidden_value(self: ForbiddenOrConjunction, values: dict[str, Any]) -> bool:
        """Check if the value is forbidden by any of the components."""
        # Generator (not a list) so evaluation stops at the first forbidding
        # component, matching is_forbidden_vector.
        return any(
            forbidden.is_forbidden_value(values) for forbidden in self.components
        )

    @override
    def is_forbidden_vector(self: ForbiddenOrConjunction, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden by any of the components."""
        return any(
            forbidden.is_forbidden_vector(vector) for forbidden in self.components
        )

    @override
    def is_forbidden_vector_array(self: ForbiddenOrConjunction, arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden by any of the components."""
        # One mask entry per column of arr; OR together every component's mask.
        forbidden_mask: Mask = np.zeros(shape=arr.shape[1], dtype=np.bool_)
        for forbidden in self.components:
            forbidden_mask |= forbidden.is_forbidden_vector_array(arr)

        return forbidden_mask

    @override
    def to_dict(self: ForbiddenOrConjunction) -> dict[str, Any]:
        """Convert the ForbiddenOrConjunction to a dictionary."""
        return {
            "type": "OR",
            "clauses": [component.to_dict() for component in self.components],
        }