Transforms
rex.base.Transform
A transformation that can be applied to parameters.
Can be used to normalize, denormalize, or transform parameters in any way.
init(*args: Any, **kwargs: Any) -> Transform
classmethod
Initialize the transform.
Parameters:
- *args (Any, default: ()) – The arguments to initialize the transform.
- **kwargs (Any, default: {}) – The keyword arguments to initialize the transform.
Returns:
- Transform – The initialized transform.
apply(params: Dict[str, Params]) -> Dict[str, Params]
Apply the transformation to the parameters.
Parameters:
- params (Dict[str, Params]) – The original parameters.
Returns:
- Dict[str, Params] – The transformed parameters.
inv(params: Dict[str, Params]) -> Dict[str, Params]
Invert the transformation.
Parameters:
- params (Dict[str, Params]) – The transformed parameters.
Returns:
- Dict[str, Params] – The original parameters.
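The contract above amounts to a pair of mutually inverse functions: `inv(apply(p))` should recover `p`. The following is a minimal, hypothetical sketch of that contract in plain Python, assuming parameters are a flat `dict` of floats rather than the pytrees rex actually uses; `TransformSketch` and `ScaleBy` are illustrative names, not part of rex.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict

# Hypothetical stand-in for rex's Params type: a flat dict of floats.
Params = Dict[str, float]


class TransformSketch(ABC):
    """Minimal sketch of the Transform contract (not rex's implementation)."""

    @classmethod
    @abstractmethod
    def init(cls, *args: Any, **kwargs: Any) -> "TransformSketch":
        ...

    @abstractmethod
    def apply(self, params: Params) -> Params:
        ...

    @abstractmethod
    def inv(self, params: Params) -> Params:
        ...


class ScaleBy(TransformSketch):
    """Hypothetical example transform: multiply every parameter by a constant."""

    def __init__(self, factor: float):
        self.factor = factor

    @classmethod
    def init(cls, factor: float = 2.0) -> "ScaleBy":
        if factor == 0:
            # A zero factor could not be undone by inv().
            raise ValueError("factor must be non-zero to be invertible")
        return cls(factor)

    def apply(self, params: Params) -> Params:
        return {k: v * self.factor for k, v in params.items()}

    def inv(self, params: Params) -> Params:
        return {k: v / self.factor for k, v in params.items()}


t = ScaleBy.init(4.0)
original = {"mass": 1.5, "friction": 0.25}
transformed = t.apply(original)  # {"mass": 6.0, "friction": 1.0}
restored = t.inv(transformed)    # {"mass": 1.5, "friction": 0.25}
```

Keeping `init` as a classmethod factory, as rex does, leaves room for validation (here, the non-zero check) before an instance is created.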
rex.base.Denormalize
Bases: Transform
(De)normalize the parameters to/from a [-1, 1] range.
Attributes:
- scale (Params) – The scale of the original parameters.
- offset (Params) – The offset of the original parameters.
init(min_params: Params, max_params: Params) -> Denormalize
classmethod
Initialize the denormalize transformation.
A non-zero scale is required, so the min and max values must differ for every parameter.
Parameters:
- min_params (Params) – The minimum values of the original parameters.
- max_params (Params) – The maximum values of the original parameters.
Returns:
- Denormalize – The denormalize transformation.
apply(params: Params) -> Params
Apply the denormalize transformation to the parameters.
Parameters:
- params (Params) – The normalized parameters.
Returns:
- Params – The denormalized parameters.
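The docstring does not spell out the formula, but an affine map between [-1, 1] and [min, max] built from `scale` and `offset` would naturally use scale = (max - min) / 2 and offset = (max + min) / 2, so that `apply` computes `normalized * scale + offset`. The sketch below assumes exactly that; `DenormalizeSketch` is a hypothetical name, and the flat `dict` parameters stand in for pytrees. It also shows why a zero scale is rejected: the inverse divides by it.

```python
from dataclasses import dataclass
from typing import Dict

Params = Dict[str, float]  # hypothetical flat stand-in for a pytree of arrays


@dataclass(frozen=True)
class DenormalizeSketch:
    scale: Params
    offset: Params

    @classmethod
    def init(cls, min_params: Params, max_params: Params) -> "DenormalizeSketch":
        scale, offset = {}, {}
        for k in min_params:
            lo, hi = min_params[k], max_params[k]
            if hi == lo:
                # A zero scale would make inv() divide by zero.
                raise ValueError(f"min and max must differ for {k!r}")
            scale[k] = (hi - lo) / 2.0   # half-width of [min, max]
            offset[k] = (hi + lo) / 2.0  # midpoint of [min, max]
        return cls(scale, offset)

    def apply(self, params: Params) -> Params:
        # Assumed mapping: [-1, 1] -> [min, max]
        return {k: v * self.scale[k] + self.offset[k] for k, v in params.items()}

    def inv(self, params: Params) -> Params:
        # Assumed mapping: [min, max] -> [-1, 1]
        return {k: (v - self.offset[k]) / self.scale[k] for k, v in params.items()}


d = DenormalizeSketch.init({"gain": 0.0, "bias": -2.0}, {"gain": 10.0, "bias": 2.0})
physical = d.apply({"gain": -1.0, "bias": 0.5})  # {"gain": 0.0, "bias": 1.0}
normalized = d.inv(physical)                     # {"gain": -1.0, "bias": 0.5}
```

Working in the normalized range lets an optimizer treat every parameter on the same scale, regardless of its physical units.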
rex.base.Extend
Bases: Transform
Extend the structure of a pytree with additional parameters from another pytree.
Useful when you only want to optimize a subset of the parameters, but the full structure is required for simulation.
Example
from rex.base import Extend
base_params = {"a": {"b": 0, "c": "1"}, "d": 2}
opt_params = {"a": None, "d": 99}
transform = Extend.init(base_params, opt_params)
extended = transform.apply(opt_params) # {"a": {"b": 0, "c": "1"}, "d": 99}
filtered = transform.inv(extended) # {"a": None, "d": 99}
Attributes:
- base_params (Params) – The base parameters.
- mask (Params) – The mask of the extended parameters.
init(base_params: Params, opt_params: Params = None) -> Extend
classmethod
Initialize the extend transformation.
Parameters:
- base_params (Params) – The base parameters.
- opt_params (Params, default: None) – The structure of the parameters that will be extended with the base parameters.
Returns:
- Extend – The extend transformation.
apply(params: Params) -> Params
Apply the extend transformation to the parameters.
Parameters:
- params (Params) – The parameters.
Returns:
- Params – The parameters, extended to the structure of the base parameters.
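One way to picture `Extend`: a `None` in `opt_params` marks a subtree that is filled in from `base_params`, and `inv` blanks those subtrees out again. The sketch below assumes that reading of the example above and handles only nested `dict`s; `ExtendSketch`, `_extend` and `_filter` are hypothetical helpers, and the stored `template` plays the role that `mask` plays in rex.

```python
from dataclasses import dataclass
from typing import Any


def _extend(base: Any, template: Any, params: Any) -> Any:
    # A None in the template marks a subtree that is taken from `base`.
    if template is None:
        return base
    if isinstance(template, dict):
        return {k: _extend(base[k], template[k], params[k]) for k in template}
    return params  # leaf provided by the optimizer


def _filter(template: Any, params: Any) -> Any:
    # Replace every subtree the template marks as None with None again.
    if template is None:
        return None
    if isinstance(template, dict):
        return {k: _filter(template[k], params[k]) for k in template}
    return params


@dataclass(frozen=True)
class ExtendSketch:
    base_params: Any
    template: Any  # stands in for rex's `mask` in this sketch

    @classmethod
    def init(cls, base_params: Any, opt_params: Any = None) -> "ExtendSketch":
        return cls(base_params, opt_params)

    def apply(self, params: Any) -> Any:
        return _extend(self.base_params, self.template, params)

    def inv(self, params: Any) -> Any:
        return _filter(self.template, params)


base_params = {"a": {"b": 0, "c": "1"}, "d": 2}
opt_params = {"a": None, "d": 99}
transform = ExtendSketch.init(base_params, opt_params)
extended = transform.apply(opt_params)  # {"a": {"b": 0, "c": "1"}, "d": 99}
filtered = transform.inv(extended)      # {"a": None, "d": 99}
```

The optimizer only ever sees the small `filtered` structure, while the simulator receives the full `extended` tree.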
rex.base.Chain
Bases: Transform
Chain multiple transformations together.
Attributes:
- transforms (Sequence[Transform]) – The transformations to chain together.
init(*transforms: Sequence[Transform]) -> Chain
classmethod
apply(params: Params) -> Params
Apply the chain of transformations to the parameters.
Parameters:
- params (Params) – The parameters.
Returns:
- Params – The transformed parameters.
inv(params: Params) -> Params
Invert the chain of transformations.
Parameters:
- params (Params) – The transformed parameters.
Returns:
- Params – The original parameters.
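Chaining composes transforms, so a correct inverse must undo them in reverse order. The hypothetical sketch below (`ChainSketch` and `Affine` are illustrative names, and parameters are flat `dict`s) makes that ordering explicit.

```python
from dataclasses import dataclass
from typing import Dict, Sequence

Params = Dict[str, float]


@dataclass(frozen=True)
class Affine:
    """Hypothetical element-wise transform x -> a * x + b (requires a != 0)."""
    a: float
    b: float

    def apply(self, params: Params) -> Params:
        return {k: self.a * v + self.b for k, v in params.items()}

    def inv(self, params: Params) -> Params:
        return {k: (v - self.b) / self.a for k, v in params.items()}


@dataclass(frozen=True)
class ChainSketch:
    transforms: Sequence[Affine]

    @classmethod
    def init(cls, *transforms: Affine) -> "ChainSketch":
        return cls(tuple(transforms))

    def apply(self, params: Params) -> Params:
        for t in self.transforms:  # first to last
            params = t.apply(params)
        return params

    def inv(self, params: Params) -> Params:
        for t in reversed(self.transforms):  # last to first
            params = t.inv(params)
        return params


chain = ChainSketch.init(Affine(2.0, 0.0), Affine(1.0, 3.0))
y = chain.apply({"x": 1.0})  # 1.0 * 2 + 3 -> {"x": 5.0}
x = chain.inv(y)             # (5.0 - 3) / 2 -> {"x": 1.0}
```

Inverting in forward order would give -0.5 instead of 1.0 here, which is why the reversal matters.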
rex.base.Shared
Bases: Transform
A shared transformation that can be applied to parameters.
Useful to share parameters between different parts of the model.
Example
from rex.base import Shared
where_fn = lambda p: p["a"]
replace_fn = lambda p: p["b"]
inverse_fn = lambda p: None
transform = Shared.init(where=where_fn, replace_fn=replace_fn, inverse_fn=inverse_fn)
opt_params = {"a": 1, "b": 2}
applied = transform.apply(opt_params) # {"a": 2, "b": 2}
inverted = transform.inv(applied) # {"a": None, "b": 2}
Attributes:
- where (Callable[[Any], Union[Any, Sequence[Any]]]) – The function that determines where to apply the transformation.
- replace_fn (Callable[[Any], Union[Any, Sequence[Any]]]) – The function that replaces the parameters.
- inverse_fn (Callable[[Any], Union[Any, Sequence[Any]]]) – The function that inverts the transformation.
init(where: Callable[[Any], Union[Any, Sequence[Any]]], replace_fn: Callable[[Any], Union[Any, Sequence[Any]]], inverse_fn: Callable[[Any], Union[Any, Sequence[Any]]] = lambda _tree: None) -> Shared
classmethod
Initialize the shared transformation.
Parameters:
- where (Callable[[Any], Union[Any, Sequence[Any]]]) – The function that determines where to apply the transformation.
- replace_fn (Callable[[Any], Union[Any, Sequence[Any]]]) – The function that replaces the parameters.
- inverse_fn (Callable[[Any], Union[Any, Sequence[Any]]], default: lambda _tree: None) – The function that inverts the transformation.
Returns:
- Shared – The shared transformation.
apply(params: Params) -> Params
Apply the shared transformation to the parameters.
Parameters:
- params (Params) – The parameters.
Returns:
- Params – The transformed parameters.
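`Shared` writes a value taken from one part of the tree into another part, so several slots are driven by a single free parameter; by default `inv` resets the overwritten slot to `None`. rex selects the slot with a callable `where`; because a plain `dict` cannot be updated through a getter, this hypothetical sketch (`SharedSketch`) uses a tuple key path instead. Libraries such as Equinox offer `eqx.tree_at(where, pytree, replace=...)` for the callable form, but the sketch does not assume rex relies on it.

```python
import copy
from typing import Any, Callable, Dict, Tuple

Path = Tuple[str, ...]  # hypothetical key path replacing rex's callable `where`


def _set(tree: Dict[str, Any], path: Path, value: Any) -> Dict[str, Any]:
    # Return a deep copy of `tree` with the node at `path` replaced by `value`.
    new = copy.deepcopy(tree)
    node = new
    for key in path[:-1]:
        node = node[key]
    node[path[-1]] = value
    return new


class SharedSketch:
    def __init__(self, where: Path, replace_fn: Callable[[Any], Any],
                 inverse_fn: Callable[[Any], Any]):
        self.where = where
        self.replace_fn = replace_fn
        self.inverse_fn = inverse_fn

    @classmethod
    def init(cls, where: Path, replace_fn: Callable[[Any], Any],
             inverse_fn: Callable[[Any], Any] = lambda _tree: None) -> "SharedSketch":
        return cls(where, replace_fn, inverse_fn)

    def apply(self, params: Dict[str, Any]) -> Dict[str, Any]:
        # Copy the shared value into the `where` slot.
        return _set(params, self.where, self.replace_fn(params))

    def inv(self, params: Dict[str, Any]) -> Dict[str, Any]:
        # Reset the slot (to None by default) so only the source stays free.
        return _set(params, self.where, self.inverse_fn(params))


shared = SharedSketch.init(where=("a",), replace_fn=lambda p: p["b"])
applied = shared.apply({"a": 1, "b": 2})  # {"a": 2, "b": 2}
inverted = shared.inv(applied)            # {"a": None, "b": 2}
```

Copying instead of mutating mirrors the functional style of pytree updates: the input tree is left untouched.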
rex.base.Identity
Bases: Transform
The identity transformation (NOOP).
init() -> Identity
classmethod
apply(params: Params) -> Params
Apply the identity transformation (NOOP).
Parameters:
- params (Params) – The parameters.
Returns:
- Params – The same parameters.
inv(params: Params) -> Params
Invert the identity transformation (NOOP).
Parameters:
- params (Params) – The parameters.
Returns:
- Params – The same parameters.
rex.base.Exponential
Bases: Transform
Apply the exponential transformation to the parameters.
init() -> Exponential
classmethod
apply(params: Params) -> Params
Apply the exponential transformation to the parameters.
Parameters:
- params (Params) – The parameters.
Returns:
- Params – The transformed parameters.
inv(params: Params) -> Params
Invert the exponential transformation.
Parameters:
- params (Params) – The transformed parameters.
Returns:
- Params – The original parameters.
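A common reason for an exponential transform is to optimize an unconstrained value while guaranteeing that the simulated parameter stays strictly positive (for example a damping coefficient). The sketch below assumes `apply` is the element-wise natural exponential and `inv` the natural logarithm; `ExponentialSketch` and the flat `dict` parameters are hypothetical stand-ins.

```python
import math
from typing import Dict

Params = Dict[str, float]  # hypothetical flat stand-in for a pytree


class ExponentialSketch:
    """Sketch assuming apply = exp and inv = natural log, element-wise."""

    @classmethod
    def init(cls) -> "ExponentialSketch":
        return cls()

    def apply(self, params: Params) -> Params:
        # Any real input maps to a strictly positive output.
        return {k: math.exp(v) for k, v in params.items()}

    def inv(self, params: Params) -> Params:
        # Only defined for strictly positive inputs.
        return {k: math.log(v) for k, v in params.items()}


exp_t = ExponentialSketch.init()
latent = {"damping": -2.0, "stiffness": 3.0}  # unconstrained values
physical = exp_t.apply(latent)                # strictly positive values
recovered = exp_t.inv(physical)               # equal to `latent` up to rounding
```

Because the round trip goes through floating-point `exp` and `log`, compare recovered values with a tolerance rather than exact equality.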