Transforms
rex.base.Transform
A transformation that can be applied to parameters and inverted.
It can be used to normalize, denormalize, or otherwise transform parameters.
init(*args: Any, **kwargs: Any) -> Transform
classmethod
Initialize the transform.
Parameters:
- *args (Any, default: ()) – The arguments to initialize the transform.
- **kwargs (Any, default: {}) – The keyword arguments to initialize the transform.
Returns:
- Transform – The initialized transform.
apply(params: Dict[str, Params]) -> Dict[str, Params]
Apply the transformation to the parameters.
Parameters:
- params (Dict[str, Params]) – The original parameters.
Returns:
- Dict[str, Params] – The transformed parameters.
inv(params: Dict[str, Params]) -> Dict[str, Params]
Invert the transformation.
Parameters:
- params (Dict[str, Params]) – The transformed parameters.
Returns:
- Dict[str, Params] – The original parameters.
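The apply/inv pair defines a contract: inv should undo apply. Below is a minimal sketch of that contract in plain Python. It assumes that Params can be any value and that transforms are built through an init classmethod, as documented above. The Scale transform is hypothetical and is not part of rex, and this is not rex's implementation.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict

# Assumption: Params stands in for rex's parameter pytree type.
Params = Any


class Transform(ABC):
    """Minimal sketch of the apply/inv contract (not rex's implementation)."""

    @abstractmethod
    def apply(self, params: Dict[str, Params]) -> Dict[str, Params]:
        """Map original parameters to transformed parameters."""

    @abstractmethod
    def inv(self, params: Dict[str, Params]) -> Dict[str, Params]:
        """Map transformed parameters back to the original parameters."""


class Scale(Transform):
    """Hypothetical transform that multiplies every value by a constant."""

    def __init__(self, factor: float):
        self.factor = factor

    @classmethod
    def init(cls, factor: float) -> "Scale":
        # Mirrors the documented pattern of constructing transforms via init().
        return cls(factor)

    def apply(self, params: Dict[str, Params]) -> Dict[str, Params]:
        return {k: v * self.factor for k, v in params.items()}

    def inv(self, params: Dict[str, Params]) -> Dict[str, Params]:
        return {k: v / self.factor for k, v in params.items()}


t = Scale.init(2.0)
original = {"mass": 1.5, "friction": 0.25}
transformed = t.apply(original)  # {"mass": 3.0, "friction": 0.5}
assert t.inv(transformed) == original  # inv undoes apply
```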
rex.base.Denormalize
Bases: Transform
(De)normalize the parameters to/from a [-1, 1] range.
Attributes:
- scale (Params) – The scale of the original parameters.
- offset (Params) – The offset of the original parameters.
init(min_params: Params, max_params: Params) -> Denormalize
classmethod
Initialize the denormalize transformation.
A non-zero scale is required, so the minimum and maximum values must differ for every parameter.
Parameters:
- min_params (Params) – The minimum values of the original parameters.
- max_params (Params) – The maximum values of the original parameters.
Returns:
- Denormalize – The denormalize transformation.
apply(params: Params) -> Params
Apply the denormalize transformation to the parameters.
Parameters:
- params (Params) – The normalized parameters.
Returns:
- Params – The denormalized parameters.
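One affine map that is consistent with "scale", "offset", and a [-1, 1] normalized range is x = offset + scale * y, with scale = (max - min) / 2 and offset = (max + min) / 2. This map sends y = -1 to min and y = 1 to max, and a zero scale would make the inverse undefined. The sketch below implements that map on flat dicts of floats. The exact formula rex uses internally is an assumption, and plain dicts stand in for JAX pytrees.

```python
from typing import Dict

# Assumption: flat dicts of floats stand in for rex's Params pytrees.
Params = Dict[str, float]


class Denormalize:
    """Sketch: maps [-1, 1] <-> [min, max] via x = offset + scale * y (assumed)."""

    def __init__(self, scale: Params, offset: Params):
        self.scale = scale
        self.offset = offset

    @classmethod
    def init(cls, min_params: Params, max_params: Params) -> "Denormalize":
        scale = {k: (max_params[k] - min_params[k]) / 2 for k in min_params}
        if any(s == 0 for s in scale.values()):
            raise ValueError("min and max must differ for every parameter")
        offset = {k: (max_params[k] + min_params[k]) / 2 for k in min_params}
        return cls(scale, offset)

    def apply(self, params: Params) -> Params:
        # Normalized [-1, 1] values -> original range.
        return {k: self.offset[k] + self.scale[k] * v for k, v in params.items()}

    def inv(self, params: Params) -> Params:
        # Original range -> normalized [-1, 1] values.
        return {k: (v - self.offset[k]) / self.scale[k] for k, v in params.items()}


d = Denormalize.init({"gain": 0.0, "delay": 0.5}, {"gain": 10.0, "delay": 2.5})
print(d.apply({"gain": -1.0, "delay": 1.0}))  # {'gain': 0.0, 'delay': 2.5}
print(d.inv({"gain": 5.0, "delay": 1.5}))  # {'gain': 0.0, 'delay': 0.0}
```

The parameter names "gain" and "delay" are purely illustrative.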
rex.base.Extend
Bases: Transform
Extend the structure of a pytree with additional parameters from another pytree.
This is useful when only a subset of the parameters is optimized but the full structure is required for simulation.
Example
from rex.base import Extend
base_params = {"a": {"b": 0, "c": "1"}, "d": 2}
opt_params = {"a": None, "d": 99}
transform = Extend.init(base_params, opt_params)
extended = transform.apply(opt_params) # {"a": {"b": 0, "c": "1"}, "d": 99}
filtered = transform.inv(extended) # {"a": None, "d": 99}
Attributes:
- base_params (Params) – The base parameters.
- mask (Params) – The mask of the extended parameters.
init(base_params: Params, opt_params: Params = None) -> Extend
classmethod
Initialize the extend transformation.
Parameters:
- base_params (Params) – The base parameters.
- opt_params (Params, default: None) – The structure of the parameters that will be extended with the base parameters.
Returns:
- Extend – The extend transformation.
apply(params: Params) -> Params
Apply the extend transformation to the parameters.
Parameters:
- params (Params) – The parameters.
Returns:
- Params – The parameters, extended to the structure of the base parameters.
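Going by the example above, a None in the optimized pytree means "take this subtree from the base parameters". The inverse keeps only the positions that the optimized pytree actually supplied. The sketch below reproduces that behavior for nested dicts. Two details are assumptions: that the mask records which positions opt_params provides, and that dicts stand in for JAX pytrees. This is not rex's implementation.

```python
from typing import Any

# Assumption: nested dicts stand in for pytrees; None marks "use the base value".
Params = Any


def _extend(base: Params, opt: Params) -> Params:
    if opt is None:  # not optimized: take the whole base subtree
        return base
    if isinstance(opt, dict):  # recurse over the full base structure
        return {k: _extend(base[k], opt.get(k)) for k in base}
    return opt  # an optimized leaf overrides the base leaf


def _mask(opt: Params) -> Params:
    # True where opt supplies a leaf, None where the base value is used.
    if opt is None:
        return None
    if isinstance(opt, dict):
        return {k: _mask(v) for k, v in opt.items()}
    return True


def _filter(params: Params, mask: Params) -> Params:
    # Keep only the positions recorded in the mask; the rest become None.
    if mask is None:
        return None
    if isinstance(mask, dict):
        return {k: _filter(params[k], m) for k, m in mask.items()}
    return params


class Extend:
    """Sketch of the Extend semantics for nested dicts (not rex's code)."""

    def __init__(self, base_params: Params, mask: Params):
        self.base_params = base_params
        self.mask = mask

    @classmethod
    def init(cls, base_params: Params, opt_params: Params = None) -> "Extend":
        return cls(base_params, _mask(opt_params))

    def apply(self, params: Params) -> Params:
        return _extend(self.base_params, params)

    def inv(self, params: Params) -> Params:
        return _filter(params, self.mask)


base_params = {"a": {"b": 0, "c": "1"}, "d": 2}
opt_params = {"a": None, "d": 99}
transform = Extend.init(base_params, opt_params)
extended = transform.apply(opt_params)  # {"a": {"b": 0, "c": "1"}, "d": 99}
filtered = transform.inv(extended)  # {"a": None, "d": 99}
```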
rex.base.Chain
Bases: Transform
Chain multiple transformations together.
Attributes:
- transforms (Sequence[Transform]) – The transformations to chain together.
init(*transforms: Sequence[Transform]) -> Chain
classmethod
apply(params: Params) -> Params
Apply the chain of transformations to the parameters.
Parameters:
- params (Params) – The parameters.
Returns:
- Params – The transformed parameters.
inv(params: Params) -> Params
Invert the chain of transformations.
Parameters:
- params (Params) – The transformed parameters.
Returns:
- Params – The original parameters.
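To be inverted correctly, a chain has to undo its steps in reverse order, since (g ∘ f)^-1 = f^-1 ∘ g^-1. The sketch below assumes that apply runs the transforms first to last and that inv runs them last to first. The source does not state this order, so treat it as an assumption. Affine is a hypothetical toy transform, and scalars stand in for parameter pytrees.

```python
from typing import Sequence


class Affine:
    """Hypothetical toy transform: apply computes a * x + b."""

    def __init__(self, a: float, b: float):
        self.a, self.b = a, b

    def apply(self, x: float) -> float:
        return self.a * x + self.b

    def inv(self, x: float) -> float:
        return (x - self.b) / self.a


class Chain:
    """Sketch: apply runs first-to-last, inv runs last-to-first (assumed order)."""

    def __init__(self, transforms: Sequence):
        self.transforms = tuple(transforms)

    @classmethod
    def init(cls, *transforms) -> "Chain":
        return cls(transforms)

    def apply(self, params):
        for t in self.transforms:
            params = t.apply(params)
        return params

    def inv(self, params):
        # Undo the transforms in reverse order.
        for t in reversed(self.transforms):
            params = t.inv(params)
        return params


double, shift = Affine(2.0, 0.0), Affine(1.0, 3.0)
chain = Chain.init(double, shift)
print(chain.apply(5.0))  # 13.0: doubled first, then shifted
print(chain.inv(13.0))  # 5.0
print(Chain.init(shift, double).apply(5.0))  # 16.0: order matters
```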
rex.base.Shared
Bases: Transform
A transformation that shares parameters between different parts of the model.
apply overwrites the nodes selected by where with the values returned by replace_fn; inv overwrites them with the values returned by inverse_fn.
Example
from rex.base import Shared
where_fn = lambda p: p["a"]
replace_fn = lambda p: p["b"]
inverse_fn = lambda p: None
transform = Shared.init(where=where_fn, replace_fn=replace_fn, inverse_fn=inverse_fn)
opt_params = {"a": 1, "b": 2}
applied = transform.apply(opt_params) # {"a": 2, "b": 2}
inverted = transform.inv(applied) # {"a": None, "b": 2}
Attributes:
- where (Callable[[Any], Union[Any, Sequence[Any]]]) – The function that selects the node(s) to which the transformation is applied.
- replace_fn (Callable[[Any], Union[Any, Sequence[Any]]]) – The function that computes the replacement value(s) for the selected node(s).
- inverse_fn (Callable[[Any], Union[Any, Sequence[Any]]]) – The function that computes the value(s) written to the selected node(s) when the transformation is inverted.
init(where: Callable[[Any], Union[Any, Sequence[Any]]], replace_fn: Callable[[Any], Union[Any, Sequence[Any]]], inverse_fn: Callable[[Any], Union[Any, Sequence[Any]]] = lambda _tree: None) -> Shared
classmethod
Initialize the shared transformation.
Parameters:
- where (Callable[[Any], Union[Any, Sequence[Any]]]) – The function that selects the node(s) to which the transformation is applied.
- replace_fn (Callable[[Any], Union[Any, Sequence[Any]]]) – The function that computes the replacement value(s) for the selected node(s).
- inverse_fn (Callable[[Any], Union[Any, Sequence[Any]]], default: lambda _tree: None) – The function that computes the value(s) written to the selected node(s) when the transformation is inverted.
Returns:
- Shared – The shared transformation.
apply(params: Params) -> Params
Apply the shared transformation to the parameters.
Parameters:
- params (Params) – The parameters.
Returns:
- Params – The transformed parameters.
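The where/replace_fn pair resembles the where-function style of equinox.tree_at. One way to find out which leaves a where function selects is to run it on a copy of the tree whose leaves are unique markers that record their key paths. The sketch below applies that idea to nested dicts and reproduces the example above. Several points are assumptions: that where selects leaves rather than subtrees, that values are computed from the input parameters, and that rex works this way internally. It is not rex's implementation.

```python
from typing import Any, Callable, Tuple

Path = Tuple[str, ...]


class _Marker:
    """Placeholder leaf that records its key path in the tree."""

    def __init__(self, path: Path):
        self.path = path


def _markers(tree: Any, prefix: Path = ()) -> Any:
    # Copy a nested dict, replacing every leaf with a _Marker of its path.
    if isinstance(tree, dict):
        return {k: _markers(v, prefix + (k,)) for k, v in tree.items()}
    return _Marker(prefix)


def _set(tree: Any, path: Path, value: Any) -> Any:
    # Return a copy of tree with the leaf at path replaced by value.
    if not path:
        return value
    new = dict(tree)
    new[path[0]] = _set(tree[path[0]], path[1:], value)
    return new


class Shared:
    """Sketch of the Shared semantics for nested dicts (not rex's code)."""

    def __init__(self, where: Callable, replace_fn: Callable, inverse_fn: Callable):
        self.where = where
        self.replace_fn = replace_fn
        self.inverse_fn = inverse_fn

    @classmethod
    def init(cls, where: Callable, replace_fn: Callable,
             inverse_fn: Callable = lambda _tree: None) -> "Shared":
        return cls(where, replace_fn, inverse_fn)

    def _substitute(self, params: Any, value_fn: Callable) -> Any:
        # Run where on a marker copy to learn which leaves it selects.
        selected = self.where(_markers(params))
        values = value_fn(params)  # values are computed from the input params
        if isinstance(selected, _Marker):  # a single leaf was selected
            selected, values = [selected], [values]
        out = params
        for marker, value in zip(selected, values):
            out = _set(out, marker.path, value)
        return out

    def apply(self, params: Any) -> Any:
        return self._substitute(params, self.replace_fn)

    def inv(self, params: Any) -> Any:
        return self._substitute(params, self.inverse_fn)


transform = Shared.init(where=lambda p: p["a"], replace_fn=lambda p: p["b"])
applied = transform.apply({"a": 1, "b": 2})  # {"a": 2, "b": 2}
inverted = transform.inv(applied)  # {"a": None, "b": 2}
```

When where selects several leaves, replace_fn and inverse_fn have to return sequences of the same length. The default inverse_fn only fits the single-leaf case.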
rex.base.Identity
Bases: Transform
The identity transformation (NOOP).
init() -> Identity
classmethod
apply(params: Params) -> Params
Apply the identity transformation (NOOP).
Parameters:
- params (Params) – The parameters.
Returns:
- Params – The same parameters.
inv(params: Params) -> Params
Invert the identity transformation (NOOP).
Parameters:
- params (Params) – The parameters.
Returns:
- Params – The same parameters.
rex.base.Exponential
Bases: Transform
Apply the exponential transformation to the parameters.
init() -> Exponential
classmethod
apply(params: Params) -> Params
Apply the exponential transformation to the parameters.
Parameters:
- params (Params) – The parameters.
Returns:
- Params – The transformed parameters.
inv(params: Params) -> Params
Invert the exponential transformation.
Parameters:
- params (Params) – The transformed parameters.
Returns:
- Params – The original parameters.
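A common reason for an exponential transform is to optimize a strictly positive quantity, such as a mass or a time constant, in log space. The optimizer then works on unconstrained values, and apply maps them back to positive ones. The sketch below assumes that apply takes the exponential of every leaf and that inv takes the natural logarithm. Nested dicts of floats stand in for pytrees. It is not rex's implementation.

```python
import math
from typing import Any, Callable

# Assumption: nested dicts with float leaves stand in for Params pytrees.
Params = Any


def _tree_map(fn: Callable[[float], float], tree: Params) -> Params:
    # Apply fn to every leaf of a nested dict.
    if isinstance(tree, dict):
        return {k: _tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)


class Exponential:
    """Sketch: leaf-wise exp in apply, log in inv (assumed, not rex's code)."""

    @classmethod
    def init(cls) -> "Exponential":
        return cls()

    def apply(self, params: Params) -> Params:
        return _tree_map(math.exp, params)

    def inv(self, params: Params) -> Params:
        # math.log raises ValueError for values <= 0.
        return _tree_map(math.log, params)


t = Exponential.init()
log_params = {"mass": 0.0, "inertia": {"xx": -1.0}}
positive = t.apply(log_params)  # every leaf is now > 0
recovered = t.inv(positive)  # back to log space
```

In this sketch, inv is only defined for strictly positive values, because math.log raises an error on zero or negative inputs. Array libraries typically return -inf or nan in that case instead.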