x = jrand.normal(jrand.PRNGKey(0), (100, 2))
Linear Model
l2_loss
l2_loss (x1, x2, weights=None)
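The signature suggests an optionally sample-weighted squared-error loss between two arrays. Below is a minimal sketch of that behaviour; the mean reduction and the elementwise weighting are assumptions, and `l2_loss_ref` is a hypothetical name, not the library function.

```python
import jax.numpy as jnp

def l2_loss_ref(x1, x2, weights=None):
    # Elementwise squared error between predictions and targets.
    sq_err = (x1 - x2) ** 2
    if weights is not None:
        # Scale each sample's error by its weight (assumed semantics).
        sq_err = sq_err * weights
    return jnp.mean(sq_err)
```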
sgd_train_linear_model
sgd_train_linear_model (X:jax.Array, y:jax.Array, weights:jax.Array=None, lr:float=0.01, n_epochs:int=100, batch_size:int=32, seed:int=42, loss_fn:Callable=<function l2_loss>, reg_term:int=None, alpha:float=1.0, fit_bias:bool=True)
Train a linear model using SGD.
|  | **Type** | **Default** | **Details** |
|---|---|---|---|
| X | jnp.ndarray |  | Input data. Shape: (N, k) |
| y | jnp.ndarray |  | Target data. Shape: (N,) or (N, 1) |
| weights | jnp.ndarray | None | Initial weights. Shape: (N,) |
| lr | float | 0.01 | Learning rate |
| n_epochs | int | 100 | Number of epochs |
| batch_size | int | 32 | Batch size |
| seed | int | 42 | Random seed |
| loss_fn | Callable | l2_loss | Loss function |
| reg_term | int | None | Regularization term |
| alpha | float | 1.0 | Regularization strength |
| fit_bias | bool | True | Fit bias term |
| **Returns** | **Tuple[np.ndarray, np.ndarray]** |  | **The trained weights and bias** |
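A usage sketch based on the signature and table above; the synthetic data and hyperparameters are illustrative, and the unpacking of the returned tuple follows the documented (weights, bias) order.

```python
import jax.numpy as jnp
import jax.random as jrand

key = jrand.PRNGKey(0)
X_demo = jrand.normal(key, (100, 2))              # (N, k)
y_demo = X_demo @ jnp.array([2.0, -1.0]) + 0.5    # (N,)

# Plain least squares (reg_term=None); returns trained weights and bias.
w_hat, b_hat = sgd_train_linear_model(X_demo, y_demo, lr=0.01, n_epochs=100)
```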
calculate_loss
calculate_loss (params:Dict[str,jax.Array], batch:Tuple[jax.Array,jax.Array,jax.Array], loss_fn:Callable, reg_term:int=None, alpha:float=1.0)
Calculate the loss for a batch of data.
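For orientation, a hypothetical sketch of how such a loss could be assembled. The parameter keys (`"w"`, `"b"`), the batch layout `(X, y, sample_weights)`, and reading `reg_term` as the order of the penalty norm (1 for L1, 2 for L2) are all assumptions, not taken from the source.

```python
import jax.numpy as jnp

def calculate_loss_sketch(params, batch, loss_fn, reg_term=None, alpha=1.0):
    X, y, sample_weights = batch                   # assumed batch layout
    preds = X @ params["w"] + params["b"]          # assumed parameter keys
    loss = loss_fn(preds, y, weights=sample_weights)
    if reg_term is not None:
        # alpha-scaled L1/L2 penalty on the weights (assumed interpretation).
        loss += alpha * jnp.sum(jnp.abs(params["w"]) ** reg_term)
    return loss
```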
BaseEstimator
BaseEstimator ()
Initialize self. See help(type(self)) for accurate signature.
LinearModel
LinearModel (intercept:bool=True, trainer_fn:Callable=None, **kwargs)
Initialize self. See help(type(self)) for accurate signature.
Lasso
Lasso (alpha:float=1.0, **kwargs)
Initialize self. See help(type(self)) for accurate signature.
Ridge
Ridge (alpha:float=1.0, **kwargs)
Initialize self. See help(type(self)) for accurate signature.
Test
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
X, y = make_regression(n_samples=500, n_features=20)
w = np.ones(X.shape[0])
sk_lm = LinearRegression()
sk_lm.fit(X, y)
sk_lm.coef_, sk_lm.intercept_
(array([ 4.39208757e+01, 5.56077362e+01, 8.41533489e+01, -1.96221110e-15,
-3.12695418e-14, 2.31507874e-14, -3.24155956e-14, 2.37743954e+00,
2.76497270e+01, -1.59962711e-15, 1.29899749e-14, -4.45028592e-14,
-4.98098774e-14, -8.17722613e-14, 9.44247846e+01, 8.93984093e+01,
5.23727300e+01, -8.15458812e-14, 7.31298648e+01, 4.14151921e+00]),
-1.0658141036401503e-14)
lm = LinearModel()
lm.fit(X, y)
lm.fit(X, y, w)
lm.coef_, lm.intercept_
(Array([ 4.3920807e+01, 5.5607498e+01, 8.4153252e+01, 2.5172596e-04,
-2.4921859e-05, -9.8353492e-05, -2.5988952e-04, 2.3774581e+00,
2.7649561e+01, 3.1019867e-04, -2.3504299e-04, 2.3154756e-04,
9.7808908e-05, -6.8332774e-05, 9.4424484e+01, 8.9398209e+01,
5.2372467e+01, -2.0462631e-04, 7.3129379e+01, 4.1418076e+00], dtype=float32),
Array([1.1218661e-05], dtype=float32))
assert np.allclose(sk_lm.coef_, lm.coef_, atol=5e-4)
assert np.allclose(sk_lm.intercept_, lm.intercept_, atol=5e-4)
lasso = Lasso(alpha=0.1)
lasso.fit(X, y)
lasso.fit(X, y, w)
lasso.coef_, lasso.intercept_
(Array([ 4.3809612e+01, 5.5486099e+01, 8.4031708e+01, -2.9320540e-04,
4.6559394e-04, -9.0057432e-04, 1.1306580e-03, 2.2790406e+00,
2.7526842e+01, -1.0292386e-03, -5.8906851e-04, 1.5921631e-03,
-1.3546057e-03, -5.7649292e-04, 9.4333282e+01, 8.9276718e+01,
5.2267292e+01, 3.2355968e-04, 7.2997238e+01, 4.0658412e+00], dtype=float32),
Array([0.00556847], dtype=float32))
ridge = Ridge(alpha=0.1)
ridge.fit(X, y)
ridge.fit(X, y, w)
ridge.coef_, ridge.intercept_
(Array([ 4.3895161e+01, 5.5568810e+01, 8.4103256e+01, 1.9272733e-03,
-1.3708844e-03, -2.7952294e-04, -5.8420287e-03, 2.3761270e+00,
2.7629921e+01, 8.9110909e-03, -1.1335424e-03, 3.1227635e-03,
1.2426455e-04, -9.2288008e-04, 9.4378990e+01, 8.9345421e+01,
5.2339722e+01, -6.7419285e-04, 7.3078979e+01, 4.1462922e+00], dtype=float32),
Array([0.00056191], dtype=float32))
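An illustrative sanity check (not part of the original notebook): with `alpha=0.1` the penalized solutions should stay close to the unregularized `LinearModel` fit above. The tolerances below are assumptions chosen to be loose relative to the printed outputs.

```python
assert np.allclose(lm.coef_, ridge.coef_, atol=1e-1)
assert np.allclose(lm.coef_, lasso.coef_, atol=5e-1)
```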