Linear Model

import jax.random as jrand

x = jrand.normal(jrand.PRNGKey(0), (100, 2))

l2_loss

 l2_loss (x1, x2, weights=None)
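The signature alone does not say how `weights` enters the loss. As a rough sketch, assuming `l2_loss` is a mean squared difference that becomes a weighted average when `weights` is given, it might look like the following. The name `l2_loss_sketch` and the normalization by the sum of the weights are assumptions for illustration, not the library's confirmed behavior.

```python
import jax.numpy as jnp


def l2_loss_sketch(x1, x2, weights=None):
    """Hypothetical stand-in for l2_loss: (optionally weighted) mean squared difference."""
    sq = (x1 - x2) ** 2
    if weights is None:
        return jnp.mean(sq)
    # Assumption: weights broadcast against sq and are normalized by their sum.
    return jnp.sum(weights * sq) / jnp.sum(weights)


pred = jnp.array([1.0, 2.0, 3.0])
target = jnp.array([1.0, 2.0, 5.0])

# Only the last element differs (by 2), so the unweighted loss is 4 / 3.
unweighted = l2_loss_sketch(pred, target)
# Giving the last sample zero weight removes its contribution entirely.
weighted = l2_loss_sketch(pred, target, jnp.array([1.0, 1.0, 0.0]))
```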

sgd_train_linear_model

 sgd_train_linear_model (X:jax.Array, y:jax.Array, weights:jax.Array=None,
                         lr:float=0.01, n_epochs:int=100,
                         batch_size:int=32, seed:int=42,
                         loss_fn:Callable=<function l2_loss>,
                         reg_term:int=None, alpha:float=1.0,
                         fit_bias:bool=True)

Train a linear model with mini-batch stochastic gradient descent (SGD).

            Type                           Default  Details
X           jnp.ndarray                             Input data. Shape: (N, k)
y           jnp.ndarray                             Target data. Shape: (N,) or (N, 1)
weights     jnp.ndarray                    None     Initial weights. Shape: (N,)
lr          float                          0.01     Learning rate
n_epochs    int                            100      Number of epochs
batch_size  int                            32       Batch size
seed        int                            42       Random seed
loss_fn     Callable                       l2_loss  Loss function
reg_term    int                            None     Regularization term
alpha       float                          1.0      Regularization strength
fit_bias    bool                           True     Fit bias term
Returns     Tuple[np.ndarray, np.ndarray]           The trained weights and bias
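To make the role of `lr`, `n_epochs`, `batch_size`, and `seed` concrete, here is a minimal mini-batch SGD loop for a linear model in JAX. It is a sketch under stated assumptions, not the implementation of `sgd_train_linear_model`: the names `sgd_fit_sketch`, `mse`, and `sgd_step`, the parameter dictionary keys `"w"` and `"b"`, and the zero initialization are all hypothetical, and sample weights and regularization are left out.

```python
import jax
import jax.numpy as jnp
import jax.random as jrand


def mse(params, xb, yb):
    # Plain mean squared error of a linear prediction on one mini-batch.
    pred = xb @ params["w"] + params["b"]
    return jnp.mean((pred - yb) ** 2)


@jax.jit
def sgd_step(params, xb, yb, lr):
    # One gradient step: p <- p - lr * dL/dp for every leaf of params.
    grads = jax.grad(mse)(params, xb, yb)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def sgd_fit_sketch(X, y, lr=0.01, n_epochs=100, batch_size=32, seed=42):
    key = jrand.PRNGKey(seed)
    n, k = X.shape
    params = {"w": jnp.zeros(k), "b": jnp.zeros(())}  # assumed zero init
    for _ in range(n_epochs):
        # Reshuffle the sample order at the start of every epoch.
        key, sub = jrand.split(key)
        perm = jrand.permutation(sub, n)
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            params = sgd_step(params, X[idx], y[idx], lr)
    return params["w"], params["b"]


# Noise-free synthetic data, so the fit should recover the true parameters.
X_demo = jrand.normal(jrand.PRNGKey(0), (200, 3))
true_w = jnp.array([2.0, -1.0, 0.5])
y_demo = X_demo @ true_w + 3.0

w_hat, b_hat = sgd_fit_sketch(X_demo, y_demo, lr=0.05, n_epochs=50)
```

Because the last batch of an epoch can be shorter than `batch_size`, the jitted step is compiled once more for that shape; dropping or padding the remainder would avoid this at the cost of a slightly different loop.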

calculate_loss

 calculate_loss (params:Dict[str,jax.Array],
                 batch:Tuple[jax.Array,jax.Array,jax.Array],
                 loss_fn:Callable, reg_term:int=None, alpha:float=1.0)

Calculate the loss for a batch of data.
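The signature suggests that the batch carries three arrays and that `reg_term` and `alpha` add a penalty on top of `loss_fn`. The sketch below shows one way such a function could be written. It assumes the batch layout is (inputs, targets, per-sample weights), that `params` holds keys `"w"` and `"b"`, and that `reg_term` is the order of the norm penalty (1 for an L1 penalty, 2 for a squared L2 penalty). None of these assumptions is confirmed by the documentation above, and `calculate_loss_sketch` and `weighted_mse` are hypothetical names.

```python
import jax.numpy as jnp


def weighted_mse(pred, target, weights):
    # Weighted mean squared error, used here as an example loss_fn.
    return jnp.sum(weights * (pred - target) ** 2) / jnp.sum(weights)


def calculate_loss_sketch(params, batch, loss_fn, reg_term=None, alpha=1.0):
    x, y, w = batch  # assumed layout: (inputs, targets, per-sample weights)
    pred = x @ params["w"] + params["b"]
    loss = loss_fn(pred, y, w)
    if reg_term is not None:
        # Assumed meaning of reg_term: 1 -> sum |w_j|, 2 -> sum w_j ** 2.
        penalty = jnp.sum(jnp.abs(params["w"]) ** reg_term)
        loss = loss + alpha * penalty
    return loss


params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.0)}
# With identity inputs the predictions equal the coefficients, so the data loss is 0
# and only the penalty term remains.
batch = (jnp.eye(2), jnp.array([1.0, -2.0]), jnp.ones(2))

plain = calculate_loss_sketch(params, batch, weighted_mse)
l1 = calculate_loss_sketch(params, batch, weighted_mse, reg_term=1, alpha=0.1)
l2 = calculate_loss_sketch(params, batch, weighted_mse, reg_term=2, alpha=0.1)
```

In this example the L1 penalty is 0.1 * (1 + 2) and the squared L2 penalty is 0.1 * (1 + 4); the bias is left unpenalized, which is a common convention but again an assumption here.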


BaseEstimator

 BaseEstimator ()


LinearModel

 LinearModel (intercept:bool=True, trainer_fn:Callable=None, **kwargs)


Lasso

 Lasso (alpha:float=1.0, **kwargs)


Ridge

 Ridge (alpha:float=1.0, **kwargs)

Test

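import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
# Synthetic regression problem: 500 samples, 20 features
X, y = make_regression(n_samples=500, n_features=20)
# One weight per sample, all ones (shape (N,)); passed as the third argument to fit
w = np.ones(X.shape[0])
# Reference fit with scikit-learn
sk_lm = LinearRegression()
sk_lm.fit(X, y)
sk_lm.coef_, sk_lm.intercept_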
(array([ 4.39208757e+01,  5.56077362e+01,  8.41533489e+01, -1.96221110e-15,
        -3.12695418e-14,  2.31507874e-14, -3.24155956e-14,  2.37743954e+00,
         2.76497270e+01, -1.59962711e-15,  1.29899749e-14, -4.45028592e-14,
        -4.98098774e-14, -8.17722613e-14,  9.44247846e+01,  8.93984093e+01,
         5.23727300e+01, -8.15458812e-14,  7.31298648e+01,  4.14151921e+00]),
 -1.0658141036401503e-14)
lm = LinearModel()
lm.fit(X, y)
lm.fit(X, y, w)
lm.coef_, lm.intercept_
(Array([ 4.3920807e+01,  5.5607498e+01,  8.4153252e+01,  2.5172596e-04,
        -2.4921859e-05, -9.8353492e-05, -2.5988952e-04,  2.3774581e+00,
         2.7649561e+01,  3.1019867e-04, -2.3504299e-04,  2.3154756e-04,
         9.7808908e-05, -6.8332774e-05,  9.4424484e+01,  8.9398209e+01,
         5.2372467e+01, -2.0462631e-04,  7.3129379e+01,  4.1418076e+00],      dtype=float32),
 Array([1.1218661e-05], dtype=float32))
assert np.allclose(sk_lm.coef_, lm.coef_, atol=5e-4)
assert np.allclose(sk_lm.intercept_, lm.intercept_, atol=5e-4)
lasso = Lasso(alpha=0.1)
lasso.fit(X, y)
lasso.fit(X, y, w)
lasso.coef_, lasso.intercept_
(Array([ 4.3809612e+01,  5.5486099e+01,  8.4031708e+01, -2.9320540e-04,
         4.6559394e-04, -9.0057432e-04,  1.1306580e-03,  2.2790406e+00,
         2.7526842e+01, -1.0292386e-03, -5.8906851e-04,  1.5921631e-03,
        -1.3546057e-03, -5.7649292e-04,  9.4333282e+01,  8.9276718e+01,
         5.2267292e+01,  3.2355968e-04,  7.2997238e+01,  4.0658412e+00],      dtype=float32),
 Array([0.00556847], dtype=float32))
ridge = Ridge(alpha=0.1)
ridge.fit(X, y)
ridge.fit(X, y, w)
ridge.coef_, ridge.intercept_
(Array([ 4.3895161e+01,  5.5568810e+01,  8.4103256e+01,  1.9272733e-03,
        -1.3708844e-03, -2.7952294e-04, -5.8420287e-03,  2.3761270e+00,
         2.7629921e+01,  8.9110909e-03, -1.1335424e-03,  3.1227635e-03,
         1.2426455e-04, -9.2288008e-04,  9.4378990e+01,  8.9345421e+01,
         5.2339722e+01, -6.7419285e-04,  7.3078979e+01,  4.1462922e+00],      dtype=float32),
 Array([0.00056191], dtype=float32))