Coverage for src / sparkle / tools / configspace.py: 100%

243 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 15:31 +0000

1"""Extensions of the ConfigSpace lib.""" 

2 

3from __future__ import annotations 

4from typing_extensions import override, Iterable 

5import ast 

6 

7import numpy as np 

8from typing import Any 

9from ConfigSpace import ConfigurationSpace 

10from ConfigSpace.hyperparameters import Hyperparameter 

11from ConfigSpace.types import Array, Mask, f64 

12 

13from ConfigSpace.conditions import ( 

14 Condition, 

15 AndConjunction, 

16 OrConjunction, 

17 EqualsCondition, 

18 GreaterThanCondition, 

19 InCondition, 

20 LessThanCondition, 

21 NotEqualsCondition, 

22) 

23 

24from ConfigSpace.forbidden import ( 

25 ForbiddenGreaterThanRelation, 

26 ForbiddenLessThanRelation, 

27 ForbiddenClause, 

28 ForbiddenConjunction, 

29 ForbiddenRelation, 

30 ForbiddenInClause, 

31 ForbiddenEqualsClause, 

32 ForbiddenAndConjunction, 

33) 

34 

35_SENTINEL = object() 

36 

37 

def expression_to_configspace(
    expression: str | ast.Module,
    configspace: ConfigurationSpace,
    target_parameter: Hyperparameter | None = None,
) -> ForbiddenClause | Condition:
    """Convert a logic expression to a ConfigSpace expression.

    Args:
        expression: The expression to convert, either as source text or as an
            already parsed ``ast.Module``.
        configspace: The ConfigSpace to use.
        target_parameter: For conditions, will parse the expression as a condition
            under which the parameter will be active. When omitted, the
            expression is converted to a forbidden clause instead.

    Returns:
        The result of ``recursive_conversion`` for the first statement of the
        expression.

    Raises:
        ValueError: If the text cannot be parsed, or if the parsed module
            contains no statement at all.
    """
    if isinstance(expression, str):
        source = expression
        try:
            expression = ast.parse(source)
        except Exception as e:
            # Chain the original error so the parser traceback is not lost.
            raise ValueError(f"Could not parse expression: '{source}', {e}") from e
    if isinstance(expression, ast.Module):
        if not expression.body:
            # An empty/whitespace-only expression has nothing to convert.
            raise ValueError("Could not parse expression: no statement found.")
        # Only the first statement is converted; any further ones are ignored.
        expression = expression.body[0]
    return recursive_conversion(
        expression, configspace, target_parameter=target_parameter
    )

61 

62 

def recursive_conversion(
    item: ast.AST | list[ast.AST],
    configspace: ConfigurationSpace,
    target_parameter: Hyperparameter | None = None,
) -> ForbiddenClause | Condition:
    """Recursively parse the AST tree to a ConfigSpace expression.

    Intermediate recursion steps also return plain values: a hyperparameter or
    a bare string for names, the literal for constants and a list of literals
    for tuples, sets and lists.

    Args:
        item: The item to parse. A list is accepted only when it holds exactly
            one node (as ``ast.Compare.comparators`` does for single comparisons).
        configspace: The ConfigSpace to use.
        target_parameter: For conditions, will parse the expression as a condition
            under which the parameter will be active. When omitted, forbidden
            clauses are produced instead.

    Returns:
        A ConfigSpace expression

    Raises:
        ValueError: For unsupported nodes, chained comparisons, hyperparameter
            names or non-literal elements inside a tuple/set/list, and
            operators without a ConfigSpace counterpart.
        NotImplementedError: For binary operations and ``is`` / ``is not``.
    """
    # Explicit None check: do not rely on the truthiness of a hyperparameter.
    as_condition = target_parameter is not None
    if isinstance(item, list):
        if len(item) != 1:
            raise ValueError(f"Can not parse list of elements: {item}.")
        item = item[0]
    if isinstance(item, ast.Expr):
        return recursive_conversion(item.value, configspace, target_parameter)
    if isinstance(item, ast.Name):  # Convert to hyperparameter
        hp = configspace.get(item.id)
        return hp if hp is not None else item.id
    if isinstance(item, ast.Constant):
        return item.value
    if isinstance(item, (ast.Tuple, ast.Set, ast.List)):
        values = []
        for v in item.elts:
            if isinstance(v, ast.Constant):
                values.append(v.value)
            elif isinstance(v, ast.Name):  # Check if its a parameter
                # Look the name up the same way as for ast.Name above; comparing
                # the name string against hyperparameter objects never matches.
                if configspace.get(v.id) is not None:
                    raise ValueError(
                        f"Only constants allowed in tuples. Found: {item.elts}"
                    )
                values.append(v.id)  # String value was interpreted as parameter
            else:
                # E.g. negative numbers parse as UnaryOp; evaluate them as
                # literals rather than silently dropping the element.
                try:
                    values.append(ast.literal_eval(v))
                except (ValueError, TypeError) as e:
                    raise ValueError(
                        f"Only constants allowed in tuples. Found: {item.elts}"
                    ) from e
        return values
    if isinstance(item, ast.BinOp):
        raise NotImplementedError("Binary operations not supported by ConfigSpace.")
    if isinstance(item, ast.BoolOp):
        values = [
            recursive_conversion(v, configspace, target_parameter) for v in item.values
        ]
        if isinstance(item.op, ast.Or):
            if as_condition:
                return OrConjunction(*values)
            return ForbiddenOrConjunction(*values)
        elif isinstance(item.op, ast.And):
            if as_condition:
                return AndConjunction(*values)
            return ForbiddenAndConjunction(*values)
        else:
            raise ValueError(f"Unknown boolean operator: {item.op}")
    if isinstance(item, ast.Compare):
        if len(item.ops) > 1:
            raise ValueError(f"Only single comparisons allowed. Found: {item.ops}")
        left = recursive_conversion(item.left, configspace, target_parameter)
        right = recursive_conversion(item.comparators, configspace, target_parameter)
        operator = item.ops[0]
        if isinstance(left, Hyperparameter):  # Convert to HP type
            # Handle special case for 'in' operator with single value
            # "text" -> ["text"], 5 -> [5]
            # We want to ensure that the right side is a list for 'in' operator
            # So that we can check membership correctly:
            # Not: hp in "hp" but: hp in ["hp", "hp2", ...]
            if isinstance(operator, ast.In) and (
                not isinstance(right, Iterable) or isinstance(right, str)
            ):
                right = [right]
            if isinstance(right, Iterable) and not isinstance(right, str):
                # Cast every value to the type of the hyperparameter's default.
                right = [type(left.default_value)(v) for v in right]
                if len(right) == 1 and not isinstance(operator, ast.In):
                    right = right[0]
            elif isinstance(right, int):
                right = type(left.default_value)(right)

        if isinstance(operator, ast.Lt):
            if as_condition:
                return LessThanCondition(target_parameter, left, right)
            return ForbiddenLessThanRelation(left=left, right=right)
        if isinstance(operator, ast.LtE):
            if as_condition:
                raise ValueError("LessThanEquals not supported for conditions.")
            return ForbiddenLessThanEqualsRelation(left=left, right=right)
        if isinstance(operator, ast.Gt):
            if as_condition:
                return GreaterThanCondition(target_parameter, left, right)
            return ForbiddenGreaterThanRelation(left=left, right=right)
        if isinstance(operator, ast.GtE):
            if as_condition:
                raise ValueError("GreaterThanEquals not supported for conditions.")
            return ForbiddenGreaterThanEqualsRelation(left=left, right=right)
        if isinstance(operator, ast.Eq):
            if as_condition:
                return EqualsCondition(target_parameter, left, right)
            return ForbiddenEqualsClause(hyperparameter=left, value=right)
        if isinstance(operator, ast.In):
            if as_condition:
                return InCondition(target_parameter, left, right)
            return ForbiddenInClause(hyperparameter=left, values=right)
        if isinstance(operator, ast.NotEq):
            if as_condition:
                return NotEqualsCondition(target_parameter, left, right)
            raise ValueError("NotEq operator not supported for ForbiddenClauses.")
        # The following classes do not (yet?) exist in configspace
        if isinstance(operator, ast.NotIn):
            raise ValueError("NotIn operator not supported for ForbiddenClauses.")
        if isinstance(operator, ast.Is):
            raise NotImplementedError("Is operator not supported.")
        if isinstance(operator, ast.IsNot):
            raise NotImplementedError("IsNot operator not supported.")
    raise ValueError(f"Unsupported type: {item}")

181 

182 

class ForbiddenLessThanEqualsRelation(ForbiddenLessThanRelation):
    """Forbid configurations whose left hyperparameter is <= the right one."""

    _RELATION_STR = "LESSEQUAL"

    def __repr__(self: ForbiddenLessThanEqualsRelation) -> str:
        """Render the relation as ``Forbidden: <left> <= <right>``."""
        return f"Forbidden: {self.left.name} <= {self.right.name}"

    @override
    def is_forbidden_value(
        self: ForbiddenLessThanEqualsRelation, values: dict[str, Any]
    ) -> bool:
        """Tell whether the given values violate the relation.

        Returns False as soon as one of the two hyperparameters has no entry
        in ``values``.
        """
        # Compared on the actual values, not on the vector representation.
        operands = []
        for hyperparameter in (self.left, self.right):
            current = values.get(hyperparameter.name, _SENTINEL)
            if current is _SENTINEL:
                return False
            operands.append(current)

        lhs, rhs = operands
        return lhs <= rhs  # type: ignore

    @override
    def is_forbidden_vector(
        self: ForbiddenLessThanEqualsRelation, vector: Array[f64]
    ) -> bool:
        """Tell whether a single vector violates the relation.

        A NaN entry on either side means the relation is not violated.
        """
        lhs: f64 = vector[self.vector_ids[0]]  # type: ignore
        rhs: f64 = vector[self.vector_ids[1]]  # type: ignore
        if not (np.isnan(lhs) or np.isnan(rhs)):
            # Map both entries back to actual values before comparing.
            return self.left.to_value(lhs) <= self.right.to_value(rhs)  # type: ignore
        return False

    @override
    def is_forbidden_vector_array(
        self: ForbiddenLessThanEqualsRelation, arr: Array[f64]
    ) -> Mask:
        """Compute the forbidden mask for an array of vectors.

        Positions where either side is NaN stay False.
        """
        lhs = arr[self.vector_ids[0]]
        rhs = arr[self.vector_ids[1]]
        comparable = ~(np.isnan(lhs) | np.isnan(rhs))
        mask = np.zeros_like(comparable)
        lhs_values = self.left.to_value(lhs[comparable])
        mask[comparable] = lhs_values <= self.right.to_value(rhs[comparable])
        return mask

231 

232 

class ForbiddenGreaterThanEqualsRelation(ForbiddenGreaterThanRelation):
    """Forbid configurations whose left hyperparameter is >= the right one."""

    _RELATION_STR = "GREATEREQUAL"

    def __repr__(self: ForbiddenGreaterThanEqualsRelation) -> str:
        """Render the relation as ``Forbidden: <left> >= <right>``."""
        return f"Forbidden: {self.left.name} >= {self.right.name}"

    @override
    def is_forbidden_value(
        self: ForbiddenGreaterThanEqualsRelation, values: dict[str, Any]
    ) -> bool:
        """Tell whether the given values violate the relation.

        Returns False as soon as one of the two hyperparameters has no entry
        in ``values``.
        """
        operands = []
        for hyperparameter in (self.left, self.right):
            current = values.get(hyperparameter.name, _SENTINEL)
            if current is _SENTINEL:
                return False
            operands.append(current)

        lhs, rhs = operands
        return lhs >= rhs  # type: ignore

    @override
    def is_forbidden_vector(
        self: ForbiddenGreaterThanEqualsRelation, vector: Array[f64]
    ) -> bool:
        """Tell whether a single vector violates the relation.

        A NaN entry on either side means the relation is not violated.
        """
        lhs: f64 = vector[self.vector_ids[0]]  # type: ignore
        rhs: f64 = vector[self.vector_ids[1]]  # type: ignore
        if not (np.isnan(lhs) or np.isnan(rhs)):
            # Map both entries back to actual values before comparing.
            return self.left.to_value(lhs) >= self.right.to_value(rhs)  # type: ignore
        return False

    @override
    def is_forbidden_vector_array(
        self: ForbiddenGreaterThanEqualsRelation, arr: Array[f64]
    ) -> Mask:
        """Compute the forbidden mask for an array of vectors.

        Positions where either side is NaN stay False.
        """
        lhs = arr[self.vector_ids[0]]
        rhs = arr[self.vector_ids[1]]
        comparable = ~(np.isnan(lhs) | np.isnan(rhs))
        mask = np.zeros_like(comparable)
        lhs_values = self.left.to_value(lhs[comparable])
        mask[comparable] = lhs_values >= self.right.to_value(rhs[comparable])
        return mask

280 

281 

class ForbiddenGreaterThanClause(ForbiddenEqualsClause):
    """A ForbiddenGreaterThanClause.

    It forbids a value from the value range of a hyperparameter to be
    *greater than* `value`.

    Args:
        hyperparameter: Hyperparameter on which the restriction is made
        value: Threshold; any value greater than it is forbidden
    """

    def __repr__(self: ForbiddenGreaterThanClause) -> str:
        """Return a string representation of the ForbiddenGreaterThanClause."""
        return f"Forbidden: {self.hyperparameter.name} > {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenGreaterThanClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        Returns False when the hyperparameter has no entry in ``values``.
        """
        current = values.get(self.hyperparameter.name, _SENTINEL)
        if current is _SENTINEL:
            # Ordering the sentinel against a value would raise a TypeError.
            return False
        return current > self.value  # type: ignore

    @override
    def is_forbidden_vector(
        self: ForbiddenGreaterThanClause, vector: Array[f64]
    ) -> bool:
        """Check if the vector is forbidden."""
        return vector[self.vector_id] > self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenGreaterThanClause, arr: Array[f64]
    ) -> Mask:
        """Check if the vector array is forbidden."""
        return np.greater(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenGreaterThanClause) -> dict[str, Any]:
        """Convert the ForbiddenGreaterThanClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "GREATER",
            "value": self.value,
        }

330 

331 

class ForbiddenGreaterEqualsClause(ForbiddenEqualsClause):
    """A ForbiddenGreaterEqualsClause.

    It forbids a value from the value range of a hyperparameter to be
    *greater or equal to* `value`.

    Args:
        hyperparameter: Hyperparameter on which the restriction is made
        value: Threshold; the value itself and anything greater is forbidden
    """

    def __repr__(self: ForbiddenGreaterEqualsClause) -> str:
        """Return a string representation of the ForbiddenGreaterEqualsClause."""
        return f"Forbidden: {self.hyperparameter.name} >= {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenGreaterEqualsClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        Returns False when the hyperparameter has no entry in ``values``.
        """
        current = values.get(self.hyperparameter.name, _SENTINEL)
        if current is _SENTINEL:
            # Ordering the sentinel against a value would raise a TypeError.
            return False
        return current >= self.value  # type: ignore

    @override
    def is_forbidden_vector(
        self: ForbiddenGreaterEqualsClause, vector: Array[f64]
    ) -> bool:
        """Check if the vector is forbidden."""
        return vector[self.vector_id] >= self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenGreaterEqualsClause, arr: Array[f64]
    ) -> Mask:
        """Check if the vector array is forbidden."""
        return np.greater_equal(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenGreaterEqualsClause) -> dict[str, Any]:
        """Convert the ForbiddenGreaterEqualsClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "GREATEREQUAL",
            "value": self.value,
        }

380 

381 

class ForbiddenLessThanClause(ForbiddenEqualsClause):
    """A ForbiddenLessThanClause.

    It forbids a value from the value range of a hyperparameter to be
    *less than* `value`.

    Args:
        hyperparameter: Hyperparameter on which the restriction is made
        value: Threshold; any value less than it is forbidden
    """

    def __repr__(self: ForbiddenLessThanClause) -> str:
        """Return a string representation of the ForbiddenLessThanClause."""
        return f"Forbidden: {self.hyperparameter.name} < {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenLessThanClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        Returns False when the hyperparameter has no entry in ``values``.
        """
        current = values.get(self.hyperparameter.name, _SENTINEL)
        if current is _SENTINEL:
            # Ordering the sentinel against a value would raise a TypeError.
            return False
        return current < self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenLessThanClause, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        return vector[self.vector_id] < self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenLessThanClause, arr: Array[f64]
    ) -> Mask:
        """Check if the vector array is forbidden."""
        return np.less(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenLessThanClause) -> dict[str, Any]:
        """Convert the ForbiddenLessThanClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "LESS",
            "value": self.value,
        }

426 

427 

class ForbiddenLessEqualsClause(ForbiddenEqualsClause):
    """A ForbiddenLessEqualsClause.

    It forbids a value from the value range of a hyperparameter to be
    *less or equal to* `value`.

    Args:
        hyperparameter: Hyperparameter on which the restriction is made
        value: Threshold; the value itself and anything less is forbidden
    """

    def __repr__(self: ForbiddenLessEqualsClause) -> str:
        """Return a string representation of the ForbiddenLessEqualsClause."""
        return f"Forbidden: {self.hyperparameter.name} <= {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenLessEqualsClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        Returns False when the hyperparameter has no entry in ``values``.
        """
        current = values.get(self.hyperparameter.name, _SENTINEL)
        if current is _SENTINEL:
            # Ordering the sentinel against a value would raise a TypeError.
            return False
        return current <= self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenLessEqualsClause, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        return vector[self.vector_id] <= self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenLessEqualsClause, arr: Array[f64]
    ) -> Mask:
        """Check if the vector array is forbidden."""
        # Must mirror is_forbidden_vector: "<=", hence less_equal.
        return np.less_equal(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenLessEqualsClause) -> dict[str, Any]:
        """Convert the ForbiddenLessEqualsClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "LESSEQUAL",
            "value": self.value,
        }

472 

473 

class ForbiddenOrConjunction(ForbiddenConjunction):
    """A ForbiddenOrConjunction.

    The ForbiddenOrConjunction combines forbidden-clauses; a configuration is
    forbidden as soon as any one of the combined clauses forbids it.

    ```python exec="true", source="material-block" result="python"
    from ConfigSpace import (
        ConfigurationSpace,
        ForbiddenEqualsClause,
        ForbiddenInClause,
    )
    from sparkle.tools.configspace import ForbiddenOrConjunction

    cs = ConfigurationSpace({"a": [1, 2, 3], "b": [2, 5, 6]})
    forbidden_clause_a = ForbiddenEqualsClause(cs["a"], 2)
    forbidden_clause_b = ForbiddenInClause(cs["b"], [2])

    forbidden_clause = ForbiddenOrConjunction(forbidden_clause_a, forbidden_clause_b)

    cs.add(forbidden_clause)
    print(cs)
    ```

    Args:
        *args: forbidden clauses, which should be combined
    """

    components: tuple[ForbiddenClause | ForbiddenConjunction | ForbiddenRelation, ...]
    """Components of the conjunction."""

    dlcs: tuple[ForbiddenClause | ForbiddenRelation, ...]
    """Descendant literal clauses of the conjunction.

    These are the base forbidden clauses/relations that are part of conjunctions.

    !!! note

        This will only store a unique set of the descendant clauses, no duplicates.
    """

    def __repr__(self: ForbiddenOrConjunction) -> str:
        """Return a string representation of the ForbiddenOrConjunction."""
        return "(" + " || ".join(str(c) for c in self.components) + ")"

    @override
    def is_forbidden_value(self: ForbiddenOrConjunction, values: dict[str, Any]) -> bool:
        """Check if the value is forbidden by at least one component."""
        # Generator instead of a list: stop at the first forbidding component,
        # matching is_forbidden_vector.
        return any(forbidden.is_forbidden_value(values) for forbidden in self.components)

    @override
    def is_forbidden_vector(self: ForbiddenOrConjunction, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden by at least one component."""
        return any(
            forbidden.is_forbidden_vector(vector) for forbidden in self.components
        )

    @override
    def is_forbidden_vector_array(self: ForbiddenOrConjunction, arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden.

        Starts from an all-False mask of length ``arr.shape[1]`` and ORs in the
        mask of every component.
        """
        forbidden_mask: Mask = np.zeros(shape=arr.shape[1], dtype=np.bool_)
        for forbidden in self.components:
            forbidden_mask |= forbidden.is_forbidden_vector_array(arr)

        return forbidden_mask

    @override
    def to_dict(self: ForbiddenOrConjunction) -> dict[str, Any]:
        """Convert the ForbiddenOrConjunction to a dictionary."""
        return {
            "type": "OR",
            "clauses": [component.to_dict() for component in self.components],
        }

548 }