Coverage for sparkle/tools/configspace.py: 61%
241 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-03 10:42 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-03 10:42 +0000
1"""Extensions of the ConfigSpace lib."""
2from __future__ import annotations
3from typing_extensions import override, Iterable
4import ast
6import numpy as np
7from typing import Any
8from ConfigSpace import ConfigurationSpace
9from ConfigSpace.hyperparameters import Hyperparameter
10from ConfigSpace.types import Array, Mask, f64
12from ConfigSpace.conditions import (
13 Condition,
14 AndConjunction,
15 OrConjunction,
16 EqualsCondition,
17 GreaterThanCondition,
18 InCondition,
19 LessThanCondition,
20 NotEqualsCondition
21)
23from ConfigSpace.forbidden import (
24 ForbiddenGreaterThanRelation,
25 ForbiddenLessThanRelation,
26 ForbiddenClause,
27 ForbiddenConjunction,
28 ForbiddenRelation,
29 ForbiddenInClause,
30 ForbiddenEqualsClause,
31 ForbiddenAndConjunction
32)
# Unique marker used as the default of ``values.get`` lookups in the forbidden
# classes below, so a missing key can be told apart from any real value
# (including ``None``).
_SENTINEL = object()
def expression_to_configspace(
        expression: str | ast.Module,
        configspace: ConfigurationSpace,
        target_parameter: Hyperparameter | None = None) -> ForbiddenClause | Condition:
    """Convert a logic expression to ConfigSpace expression.

    Args:
        expression: The expression to convert, either as source text or as an
            already parsed module. It must contain exactly one statement.
        configspace: The ConfigSpace to use.
        target_parameter: For conditions, will parse the expression as a condition
            under which the parameter will be active. When None, the expression
            is converted to a forbidden clause.

    Returns:
        The ConfigSpace condition or forbidden clause equivalent to the expression.

    Raises:
        ValueError: If the expression cannot be parsed, or does not contain
            exactly one statement.
    """
    if isinstance(expression, str):
        try:
            expression = ast.parse(expression)
        except (SyntaxError, ValueError) as e:
            # compile()/ast.parse raise SyntaxError for invalid source and
            # ValueError for source containing null bytes.
            raise ValueError(f"Could not parse expression: '{expression}', {e}") from e
    if isinstance(expression, ast.Module):
        # An empty module would fail with a bare IndexError, and additional
        # statements would otherwise be silently dropped.
        if len(expression.body) != 1:
            raise ValueError("Expected exactly one expression, found "
                             f"{len(expression.body)} statements.")
        expression = expression.body[0]
    return recursive_conversion(expression, configspace,
                                target_parameter=target_parameter)
def recursive_conversion(
        item: ast.AST | list[ast.AST],
        configspace: ConfigurationSpace,
        target_parameter: Hyperparameter | None = None) -> ForbiddenClause | Condition:
    """Recursively parse the AST tree to a ConfigSpace expression.

    Leaf nodes become plain values:
    - A name becomes the matching hyperparameter of ``configspace``, or the bare
      name as a string if no such parameter exists.
    - A constant, or a signed numeric constant such as ``-1``, becomes its value.
    - A tuple, set or list becomes a Python list of constants.

    Args:
        item: The item to parse. A single-element list (e.g. ``ast.Module.body``
            or ``ast.Compare.comparators``) is unwrapped first.
        configspace: The ConfigSpace to use.
        target_parameter: For conditions, will parse the expression as a condition
            under which the parameter will be active. When None, the expression
            is parsed as a forbidden clause.

    Returns:
        A ConfigSpace expression, or the converted value for leaf nodes.

    Raises:
        ValueError: If the expression contains an unsupported construct.
        NotImplementedError: For operators that ConfigSpace has no equivalent for.
    """
    if isinstance(item, list):
        if len(item) != 1:
            raise ValueError(f"Can not parse list of elements: {item}.")
        item = item[0]
    if isinstance(item, ast.Expr):
        return recursive_conversion(item.value, configspace, target_parameter)
    if isinstance(item, ast.Name):  # Convert to hyperparameter
        hp = configspace.get(item.id)
        return hp if hp is not None else item.id
    if isinstance(item, ast.Constant):
        return item.value
    if isinstance(item, ast.UnaryOp) and isinstance(item.op, (ast.UAdd, ast.USub)):
        # Python parses signed literals such as -1 as UnaryOp, not Constant
        return _signed_constant(item)
    if isinstance(item, (ast.Tuple, ast.Set, ast.List)):
        return _convert_collection(item, configspace)
    if isinstance(item, ast.BinOp):
        raise NotImplementedError("Binary operations not supported by ConfigSpace.")
    if isinstance(item, ast.BoolOp):
        return _convert_bool_op(item, configspace, target_parameter)
    if isinstance(item, ast.Compare):
        return _convert_comparison(item, configspace, target_parameter)
    raise ValueError(f"Unsupported type: {item}")


def _signed_constant(item: ast.UnaryOp) -> Any:
    """Evaluate a unary +/- applied to a numeric literal, e.g. ``-1``."""
    try:
        return ast.literal_eval(item)
    except ValueError as e:
        raise ValueError("Only signed numeric constants are supported. "
                         f"Found: {ast.dump(item)}") from e


def _convert_collection(item: ast.Tuple | ast.Set | ast.List,
                        configspace: ConfigurationSpace) -> list[Any]:
    """Convert a literal tuple/set/list to a list of plain constant values."""
    values = []
    for element in item.elts:
        if isinstance(element, ast.Constant):
            values.append(element.value)
        elif (isinstance(element, ast.UnaryOp)
              and isinstance(element.op, (ast.UAdd, ast.USub))):
            values.append(_signed_constant(element))
        elif isinstance(element, ast.Name):
            # A name that refers to a parameter is a reference, not a value
            if configspace.get(element.id) is not None:
                raise ValueError("Only constants allowed in tuples. "
                                 f"Found: {item.elts}")
            values.append(element.id)  # String value was interpreted as parameter
        else:
            # Previously such elements were silently dropped from the values
            raise ValueError("Only constants allowed in tuples. "
                             f"Found unsupported element: {ast.dump(element)}")
    return values


def _convert_bool_op(
        item: ast.BoolOp,
        configspace: ConfigurationSpace,
        target_parameter: Hyperparameter | None) -> ForbiddenClause | Condition:
    """Convert an ``and``/``or`` expression to a (forbidden) conjunction."""
    values = [recursive_conversion(v, configspace, target_parameter)
              for v in item.values]
    is_condition = target_parameter is not None
    if isinstance(item.op, ast.Or):
        if is_condition:
            return OrConjunction(*values)
        return ForbiddenOrConjunction(*values)
    if isinstance(item.op, ast.And):
        if is_condition:
            return AndConjunction(*values)
        return ForbiddenAndConjunction(*values)
    raise ValueError(f"Unknown boolean operator: {item.op}")


def _convert_comparison(
        item: ast.Compare,
        configspace: ConfigurationSpace,
        target_parameter: Hyperparameter | None) -> ForbiddenClause | Condition:
    """Convert a single comparison such as ``a < 3`` or ``a < b``.

    Comparisons against another hyperparameter become forbidden relations.
    Comparisons against constants become forbidden clauses.
    """
    if len(item.ops) > 1:
        raise ValueError(f"Only single comparisons allowed. Found: {item.ops}")
    left = recursive_conversion(item.left, configspace, target_parameter)
    right = recursive_conversion(item.comparators, configspace, target_parameter)
    operator = item.ops[0]
    if not isinstance(left, Hyperparameter):
        raise ValueError("The left-hand side of a comparison must be a parameter of "
                         f"the configuration space. Found: {left!r}")
    right_is_parameter = isinstance(right, Hyperparameter)
    if not right_is_parameter:  # Convert constants to the HP's value type
        hp_type = type(left.default_value)
        if isinstance(right, Iterable) and not isinstance(right, str):
            right = [hp_type(v) for v in right]
            if len(right) == 1 and not isinstance(operator, ast.In):
                right = right[0]
        elif isinstance(right, int):
            right = hp_type(right)

    is_condition = target_parameter is not None
    if isinstance(operator, ast.Lt):
        if is_condition:
            return LessThanCondition(target_parameter, left, right)
        if right_is_parameter:
            return ForbiddenLessThanRelation(left=left, right=right)
        return ForbiddenLessThanClause(hyperparameter=left, value=right)
    if isinstance(operator, ast.LtE):
        if is_condition:
            raise ValueError("LessThanEquals not supported for conditions.")
        if right_is_parameter:
            return ForbiddenLessThanEqualsRelation(left=left, right=right)
        return ForbiddenLessEqualsClause(hyperparameter=left, value=right)
    if isinstance(operator, ast.Gt):
        if is_condition:
            return GreaterThanCondition(target_parameter, left, right)
        if right_is_parameter:
            return ForbiddenGreaterThanRelation(left=left, right=right)
        return ForbiddenGreaterThanClause(hyperparameter=left, value=right)
    if isinstance(operator, ast.GtE):
        if is_condition:
            raise ValueError("GreaterThanEquals not supported for conditions.")
        if right_is_parameter:
            return ForbiddenGreaterThanEqualsRelation(left=left, right=right)
        return ForbiddenGreaterEqualsClause(hyperparameter=left, value=right)
    if isinstance(operator, ast.Eq):
        if is_condition:
            return EqualsCondition(target_parameter, left, right)
        if right_is_parameter:
            from ConfigSpace.forbidden import ForbiddenEqualsRelation
            return ForbiddenEqualsRelation(left=left, right=right)
        return ForbiddenEqualsClause(hyperparameter=left, value=right)
    if isinstance(operator, ast.In):
        if is_condition:
            return InCondition(target_parameter, left, right)
        return ForbiddenInClause(hyperparameter=left, values=right)
    if isinstance(operator, ast.NotEq):
        if is_condition:
            return NotEqualsCondition(target_parameter, left, right)
        raise ValueError("NotEq operator not supported for ForbiddenClauses.")
    # The following classes do not (yet?) exist in configspace
    if isinstance(operator, ast.NotIn):
        raise ValueError("NotIn operator not supported for ForbiddenClauses.")
    if isinstance(operator, ast.Is):
        raise NotImplementedError("Is operator not supported.")
    if isinstance(operator, ast.IsNot):
        raise NotImplementedError("IsNot operator not supported.")
    raise ValueError(f"Unsupported type: {item}")
class ForbiddenLessThanEqualsRelation(ForbiddenLessThanRelation):
    """A ForbiddenLessThanEquals relation between two hyperparameters.

    Forbids every configuration in which the value of ``left`` is less than or
    equal to the value of ``right``. The comparison is made on decoded values
    (``to_value``), not on the vector encoding. A missing value, or a NaN vector
    entry, on either side never forbids anything.
    """

    _RELATION_STR = "LESSEQUAL"

    def __repr__(self: ForbiddenLessThanEqualsRelation) -> str:
        """Return a string representation of the ForbiddenLessThanEqualsRelation."""
        return f"Forbidden: {self.left.name} <= {self.right.name}"

    @override
    def is_forbidden_value(self: ForbiddenLessThanEqualsRelation,
                           values: dict[str, Any]) -> bool:
        """Check if the value is forbidden."""
        lhs = values.get(self.left.name, _SENTINEL)
        rhs = values.get(self.right.name, _SENTINEL)
        if lhs is _SENTINEL or rhs is _SENTINEL:
            # The relation needs both operands to be violated
            return False
        return lhs <= rhs  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenLessThanEqualsRelation,
                            vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        lhs_encoded: f64 = vector[self.vector_ids[0]]  # type: ignore
        rhs_encoded: f64 = vector[self.vector_ids[1]]  # type: ignore
        if np.isnan(lhs_encoded) or np.isnan(rhs_encoded):
            return False
        # Decode both sides first: each hyperparameter has its own encoding
        lhs = self.left.to_value(lhs_encoded)
        rhs = self.right.to_value(rhs_encoded)
        return lhs <= rhs  # type: ignore

    @override
    def is_forbidden_vector_array(self: ForbiddenLessThanEqualsRelation,
                                  arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden."""
        lhs_entries = arr[self.vector_ids[0]]
        rhs_entries = arr[self.vector_ids[1]]
        comparable = ~np.isnan(lhs_entries) & ~np.isnan(rhs_entries)
        result = np.zeros_like(comparable)
        result[comparable] = (self.left.to_value(lhs_entries[comparable])
                              <= self.right.to_value(rhs_entries[comparable]))
        return result
class ForbiddenGreaterThanEqualsRelation(ForbiddenGreaterThanRelation):
    """A ForbiddenGreaterThanEquals relation between two hyperparameters.

    Forbids every configuration in which the value of ``left`` is greater than
    or equal to the value of ``right``. The comparison is made on decoded values
    (``to_value``), not on the vector encoding. A missing value, or a NaN vector
    entry, on either side never forbids anything.
    """

    _RELATION_STR = "GREATEREQUAL"

    def __repr__(self: ForbiddenGreaterThanEqualsRelation) -> str:
        """Return a string representation of the ForbiddenGreaterThanEqualsRelation."""
        return f"Forbidden: {self.left.name} >= {self.right.name}"

    @override
    def is_forbidden_value(self: ForbiddenGreaterThanEqualsRelation,
                           values: dict[str, Any]) -> bool:
        """Check if the value is forbidden."""
        lhs = values.get(self.left.name, _SENTINEL)
        rhs = values.get(self.right.name, _SENTINEL)
        if lhs is _SENTINEL or rhs is _SENTINEL:
            # The relation needs both operands to be violated
            return False
        return lhs >= rhs  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenGreaterThanEqualsRelation,
                            vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        lhs_encoded: f64 = vector[self.vector_ids[0]]  # type: ignore
        rhs_encoded: f64 = vector[self.vector_ids[1]]  # type: ignore
        if np.isnan(lhs_encoded) or np.isnan(rhs_encoded):
            return False
        # Decode both sides first: each hyperparameter has its own encoding
        lhs = self.left.to_value(lhs_encoded)
        rhs = self.right.to_value(rhs_encoded)
        return lhs >= rhs  # type: ignore

    @override
    def is_forbidden_vector_array(self: ForbiddenGreaterThanEqualsRelation,
                                  arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden."""
        lhs_entries = arr[self.vector_ids[0]]
        rhs_entries = arr[self.vector_ids[1]]
        comparable = ~np.isnan(lhs_entries) & ~np.isnan(rhs_entries)
        result = np.zeros_like(comparable)
        result[comparable] = (self.left.to_value(lhs_entries[comparable])
                              >= self.right.to_value(rhs_entries[comparable]))
        return result
class ForbiddenGreaterThanClause(ForbiddenEqualsClause):
    """A ForbiddenGreaterThanClause.

    It forbids a value from the value range of a hyperparameter to be
    *greater than* `value`. For example, ``ForbiddenGreaterThanClause(cs["a"], 2)``
    forbids every configuration in which ``a > 2``.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: forbidden bound; values strictly greater than it are forbidden
    """

    def __repr__(self: ForbiddenGreaterThanClause) -> str:
        """Return a string representation of the ForbiddenGreaterThanClause."""
        return f"Forbidden: {self.hyperparameter.name} > {self.value!r}"

    @override
    def is_forbidden_value(self: ForbiddenGreaterThanClause,
                           values: dict[str, Any]) -> bool:
        """Check if the value is forbidden."""
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            # A missing value cannot violate the clause (and ordering the
            # sentinel object against a number would raise TypeError).
            return False
        return value > self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenGreaterThanClause,
                            vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        # NaN entries compare False, so they are never forbidden.
        # NOTE(review): comparing in vector space assumes the hyperparameter's
        # encoding preserves value order - confirm for non-numeric parameters.
        return vector[self.vector_id] > self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(self: ForbiddenGreaterThanClause,
                                  arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden."""
        return np.greater(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenGreaterThanClause) -> dict[str, Any]:
        """Convert the ForbiddenGreaterThanClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "GREATER",
            "value": self.value,
        }
class ForbiddenGreaterEqualsClause(ForbiddenEqualsClause):
    """A ForbiddenGreaterEqualsClause.

    It forbids a value from the value range of a hyperparameter to be
    *greater or equal to* `value`. For example,
    ``ForbiddenGreaterEqualsClause(cs["a"], 2)`` forbids every configuration in
    which ``a >= 2``.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: forbidden bound; values greater than or equal to it are forbidden
    """

    def __repr__(self: ForbiddenGreaterEqualsClause) -> str:
        """Return a string representation of the ForbiddenGreaterEqualsClause."""
        return f"Forbidden: {self.hyperparameter.name} >= {self.value!r}"

    @override
    def is_forbidden_value(self: ForbiddenGreaterEqualsClause,
                           values: dict[str, Any]) -> bool:
        """Check if the value is forbidden."""
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            # A missing value cannot violate the clause (and ordering the
            # sentinel object against a number would raise TypeError).
            return False
        return value >= self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenGreaterEqualsClause,
                            vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        # NaN entries compare False, so they are never forbidden.
        # NOTE(review): comparing in vector space assumes the hyperparameter's
        # encoding preserves value order - confirm for non-numeric parameters.
        return vector[self.vector_id] >= self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(self: ForbiddenGreaterEqualsClause,
                                  arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden."""
        return np.greater_equal(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenGreaterEqualsClause) -> dict[str, Any]:
        """Convert the ForbiddenGreaterEqualsClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "GREATEREQUAL",
            "value": self.value,
        }
class ForbiddenLessThanClause(ForbiddenEqualsClause):
    """A ForbiddenLessThanClause.

    It forbids a value from the value range of a hyperparameter to be
    *less than* `value`. For example, ``ForbiddenLessThanClause(cs["a"], 2)``
    forbids every configuration in which ``a < 2``.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: forbidden bound; values strictly less than it are forbidden
    """

    def __repr__(self: ForbiddenLessThanClause) -> str:
        """Return a string representation of the ForbiddenLessThanClause."""
        return f"Forbidden: {self.hyperparameter.name} < {self.value!r}"

    @override
    def is_forbidden_value(self: ForbiddenLessThanClause,
                           values: dict[str, Any]) -> bool:
        """Check if the value is forbidden."""
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            # A missing value cannot violate the clause (and ordering the
            # sentinel object against a number would raise TypeError).
            return False
        return value < self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenLessThanClause, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        # NaN entries compare False, so they are never forbidden.
        # NOTE(review): comparing in vector space assumes the hyperparameter's
        # encoding preserves value order - confirm for non-numeric parameters.
        return vector[self.vector_id] < self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(self: ForbiddenLessThanClause,
                                  arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden."""
        return np.less(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenLessThanClause) -> dict[str, Any]:
        """Convert the ForbiddenLessThanClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "LESS",
            "value": self.value,
        }
class ForbiddenLessEqualsClause(ForbiddenEqualsClause):
    """A ForbiddenLessEqualsClause.

    It forbids a value from the value range of a hyperparameter to be
    *less or equal to* `value`. For example,
    ``ForbiddenLessEqualsClause(cs["a"], 2)`` forbids every configuration in
    which ``a <= 2``.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: forbidden bound; values less than or equal to it are forbidden
    """

    def __repr__(self: ForbiddenLessEqualsClause) -> str:
        """Return a string representation of the ForbiddenLessEqualsClause."""
        return f"Forbidden: {self.hyperparameter.name} <= {self.value!r}"

    @override
    def is_forbidden_value(self: ForbiddenLessEqualsClause,
                           values: dict[str, Any]) -> bool:
        """Check if the value is forbidden."""
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            # A missing value cannot violate the clause (and ordering the
            # sentinel object against a number would raise TypeError).
            return False
        return value <= self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenLessEqualsClause, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        # NaN entries compare False, so they are never forbidden.
        # NOTE(review): comparing in vector space assumes the hyperparameter's
        # encoding preserves value order - confirm for non-numeric parameters.
        return vector[self.vector_id] <= self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(self: ForbiddenLessEqualsClause,
                                  arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden."""
        # Must mirror is_forbidden_vector (<=); this previously used
        # np.greater_equal and so forbade the opposite region.
        return np.less_equal(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenLessEqualsClause) -> dict[str, Any]:
        """Convert the ForbiddenLessEqualsClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "LESSEQUAL",
            "value": self.value,
        }
class ForbiddenOrConjunction(ForbiddenConjunction):
    """A ForbiddenOrConjunction.

    The ForbiddenOrConjunction combines forbidden-clauses with a logical OR: a
    configuration is forbidden as soon as any of its components forbids it.
    This allows building powerful constraints.

    ```python exec="true", source="material-block" result="python"
    from ConfigSpace import (
        ConfigurationSpace,
        ForbiddenEqualsClause,
        ForbiddenInClause,
    )
    from sparkle.tools.configspace import ForbiddenOrConjunction

    cs = ConfigurationSpace({"a": [1, 2, 3], "b": [2, 5, 6]})
    forbidden_clause_a = ForbiddenEqualsClause(cs["a"], 2)
    forbidden_clause_b = ForbiddenInClause(cs["b"], [2])

    forbidden_clause = ForbiddenOrConjunction(forbidden_clause_a, forbidden_clause_b)

    cs.add(forbidden_clause)
    print(cs)
    ```

    Args:
        *args: forbidden clauses, which should be combined
    """

    components: tuple[ForbiddenClause | ForbiddenConjunction | ForbiddenRelation, ...]
    """Components of the conjunction."""

    dlcs: tuple[ForbiddenClause | ForbiddenRelation, ...]
    """Descendant literal clauses of the conjunction.

    These are the base forbidden clauses/relations that are part of conjunctions.

    !!! note

        This will only store a unique set of the descendant clauses, no duplicates.
    """

    def __repr__(self: ForbiddenOrConjunction) -> str:
        """Return a string representation of the ForbiddenOrConjunction."""
        return "(" + " || ".join(str(c) for c in self.components) + ")"

    @override
    def is_forbidden_value(self: ForbiddenOrConjunction, values: dict[str, Any]) -> bool:
        """Check if the value is forbidden."""
        # Generator (not a list) so evaluation stops at the first forbidding
        # component, consistent with is_forbidden_vector.
        return any(
            forbidden.is_forbidden_value(values) for forbidden in self.components
        )

    @override
    def is_forbidden_vector(self: ForbiddenOrConjunction, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        return any(
            forbidden.is_forbidden_vector(vector) for forbidden in self.components
        )

    @override
    def is_forbidden_vector_array(self: ForbiddenOrConjunction, arr: Array[f64]) -> Mask:
        """Check if the vector array is forbidden."""
        # One entry per column of arr; OR-accumulate every component's mask
        forbidden_mask: Mask = np.zeros(shape=arr.shape[1], dtype=np.bool_)
        for forbidden in self.components:
            forbidden_mask |= forbidden.is_forbidden_vector_array(arr)

        return forbidden_mask

    @override
    def to_dict(self: ForbiddenOrConjunction) -> dict[str, Any]:
        """Convert the ForbiddenOrConjunction to a dictionary."""
        return {
            "type": "OR",
            "clauses": [component.to_dict() for component in self.components],
        }