Source code for skdownscale.pointwise_models.trend
from __future__ import annotations
from typing import Any
import numpy as np
from numpy.typing import ArrayLike, NDArray
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LinearRegression
from sklearn.utils.validation import check_is_fitted, validate_data
from .utils import default_none_kwargs
[docs]
class LinearTrendTransformer(TransformerMixin, BaseEstimator):
"""Transform features by removing linear trends.
Uses Ordinary least squares Linear Regression as implemented in
sklear.linear_model.LinearRegression.
Parameters
----------
**lr_kwargs
Keyword arguments to pass to sklearn.linear_model.LinearRegression
Attributes
----------
lr_model_ : sklearn.linear_model.LinearRegression
Linear Regression object.
"""
[docs]
def __init__(self, lr_kwargs: dict[str, Any] | None = None) -> None:
self.lr_kwargs = lr_kwargs
def _validate_data(
self, X: ArrayLike, y: ArrayLike | None = None, reset: bool = True, **check_params: Any
) -> ArrayLike | tuple[ArrayLike, ArrayLike]:
"""Validate input data using sklearn's validate_data."""
return validate_data(self, X=X, y=y, reset=reset, **check_params)
[docs]
def fit(self, X: ArrayLike, y: ArrayLike | None = None) -> LinearTrendTransformer:
"""Compute the linear trend.
Parameters
----------
X : array-like, shape [n_samples, n_features]
Training data.
"""
X = self._validate_data(X)
kwargs = default_none_kwargs(self.lr_kwargs)
self.lr_model_ = LinearRegression(**kwargs)
self.lr_model_.fit(np.arange(len(X)).reshape(-1, 1), X)
return self
[docs]
def transform(self, X: ArrayLike) -> NDArray[Any]:
"""Perform transformation by removing the trend.
Parameters
----------
X : array-like, shape [n_samples, n_features]
The data that should be detrended.
"""
# validate input data
check_is_fitted(self)
X = self._validate_data(X)
return X - self.trendline(X)
[docs]
def inverse_transform(self, X: ArrayLike) -> NDArray[Any]:
"""Add the trend back to the data.
Parameters
----------
X : array-like, shape [n_samples, n_features]
The data that should be transformed back.
"""
# validate input data
check_is_fitted(self)
X = self._validate_data(X)
return X + self.trendline(X)
[docs]
def trendline(self, X: ArrayLike) -> NDArray[Any]:
"""helper function to calculate a linear trendline"""
X = self._validate_data(X)
return self.lr_model_.predict(np.arange(len(X)).reshape(-1, 1))
def __sklearn_tags__(self):
from dataclasses import replace
tags = super().__sklearn_tags__()
# Mark as skipping certain tests due to temporal sensitivity
tags = replace(tags, _skip_test='Temporal transformer - sample order matters')
return tags