Note that the backbone and activations models are not created with keras.Input objects, but with tensors that originate from keras.Input objects. Under the hood, the layers and weights are shared across these models, so the user can train full_model and use backbone or activations for feature extraction. The inputs and outputs of a model can also be nested structures of tensors, and the resulting models are standard Functional API models that support all the existing APIs.
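The snippet below is a minimal sketch of this sharing, assuming a small all-Dense model. The layer sizes and the names hidden, features and head are hypothetical and not from this page. head is built from an intermediate tensor rather than from a keras.Input. Chaining backbone and head reproduces full_model's output because all three models reuse the same layer objects.

```python
import numpy as np
import keras

# Illustrative functional model; layer sizes are arbitrary assumptions.
inputs = keras.Input(shape=(16,))
hidden = keras.layers.Dense(32, activation="relu")(inputs)
features = keras.layers.Dense(8, activation="relu")(hidden)
outputs = keras.layers.Dense(3)(features)

# The full model, which would normally be trained end to end.
full_model = keras.Model(inputs, outputs)
# Sub-models built from the same symbolic tensors reuse the same layers.
backbone = keras.Model(inputs, features)  # feature extractor
head = keras.Model(features, outputs)     # starts from an intermediate tensor

x = np.random.rand(4, 16).astype("float32")
full_out = full_model.predict(x, verbose=0)
# Because weights are shared, backbone followed by head matches full_model.
chained_out = head.predict(backbone.predict(x, verbose=0), verbose=0)
```

Training full_model updates the same variables that backbone reads, so features extracted after training reflect the trained weights.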
By subclassing the Model class
In this case, define your layers in __init__() and implement the model's forward pass in call().
```python
import keras

class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = keras.layers.Dense(32, activation="relu")
        self.dense2 = keras.layers.Dense(5, activation="softmax")

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

model = MyModel()
```
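A subclassed model declares no input shape, so its layers create their weights on the first call. The sketch below makes that visible. It assumes an input width of 10 and a batch of 2, both arbitrary. After one predict() call, the model holds four variables (a kernel and a bias for each Dense layer), and each row of the softmax output sums to 1.

```python
import numpy as np
import keras

class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = keras.layers.Dense(32, activation="relu")
        self.dense2 = keras.layers.Dense(5, activation="softmax")

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

model = MyModel()
# The first call builds dense1 and dense2 for the input width it sees (10 here).
probs = model.predict(np.ones((2, 10), dtype="float32"), verbose=0)
```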
If you subclass Model, you can optionally add a boolean training argument to call(), which you can use to specify different behavior in training and inference:
```python
import keras

class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = keras.layers.Dense(32, activation="relu")
        self.dense2 = keras.layers.Dense(5, activation="softmax")
        self.dropout = keras.layers.Dropout(0.5)

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        # Dropout is only applied when training=True.
        x = self.dropout(x, training=training)
        return self.dense2(x)

model = MyModel()
```
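The following sketch shows the effect of the training flag, assuming an all-ones input of width 8 (an arbitrary choice) and a backend whose outputs np.asarray can convert, such as TensorFlow or JAX. With training=False, Dropout acts as the identity, so repeated calls agree. With training=True, units are randomly dropped and the survivors are rescaled, so the output differs from inference.

```python
import numpy as np
import keras

class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = keras.layers.Dense(32, activation="relu")
        self.dense2 = keras.layers.Dense(5, activation="softmax")
        self.dropout = keras.layers.Dropout(0.5)

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        x = self.dropout(x, training=training)
        return self.dense2(x)

keras.utils.set_random_seed(0)  # make weight init and dropout masks reproducible
model = MyModel()
x = np.ones((4, 8), dtype="float32")

# Inference mode: deterministic, so two calls give identical results.
infer_a = np.asarray(model(x, training=False))
infer_b = np.asarray(model(x, training=False))
# Training mode: a random dropout mask changes the activations.
train_out = np.asarray(model(x, training=True))
```

fit() calls the model with training=True, and predict() calls it with training=False, so you rarely pass the flag yourself.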
Once the model is created, you can configure it with losses and metrics using model.compile(), train it with model.fit(), or use it for prediction with model.predict().
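Here is a minimal sketch of that compile, fit and predict cycle. It uses random toy data (64 samples, 10 features, 5 integer classes), the "adam" optimizer and sparse categorical cross-entropy. All of these are illustrative choices, not requirements. The model is a small functional one, but the same three calls work on a subclassed or Sequential model.

```python
import numpy as np
import keras

# Toy data; shapes and class count are illustrative assumptions.
rng = np.random.default_rng(0)
x_train = rng.random((64, 10)).astype("float32")
y_train = rng.integers(0, 5, size=(64,))  # integer class labels

inputs = keras.Input(shape=(10,))
x = keras.layers.Dense(32, activation="relu")(inputs)
outputs = keras.layers.Dense(5, activation="softmax")(x)
model = keras.Model(inputs, outputs)

# Configure the loss, optimizer and metrics.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",  # matches integer labels
    metrics=["accuracy"],
)
# Train, then predict class probabilities for a few samples.
history = model.fit(x_train, y_train, epochs=2, batch_size=16, verbose=0)
probs = model.predict(x_train[:3], verbose=0)
```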
With the Sequential class
In addition, keras.Sequential is a special case of Model where the model is purely a stack of single-input, single-output layers.
```python
import keras

model = keras.Sequential([
    keras.Input(shape=(None, None, 3)),
    keras.layers.Conv2D(filters=32, kernel_size=3),
])
```
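Because the spatial dimensions in keras.Input(shape=(None, None, 3)) are None, this model accepts RGB images of any height and width. The sketch below feeds it two zero-filled batches of different sizes (the sizes are arbitrary). Conv2D defaults to padding="valid", so a 3x3 kernel shrinks each spatial dimension by 2.

```python
import numpy as np
import keras

model = keras.Sequential([
    keras.Input(shape=(None, None, 3)),
    keras.layers.Conv2D(filters=32, kernel_size=3),
])

# Different image sizes pass through the same model.
out_small = model.predict(np.zeros((1, 10, 10, 3), dtype="float32"), verbose=0)
out_large = model.predict(np.zeros((2, 64, 48, 3), dtype="float32"), verbose=0)
# With "valid" padding: 10x10 -> 8x8 and 64x48 -> 62x46, each with 32 channels.
```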