Base APIs

relax.base.BaseConfig

[source]

class relax.base.BaseConfig ()

Base class for all config classes.

class ConfigTest(BaseConfig):
    """Minimal `BaseConfig` subclass with an int, a str and a float field.

    Every field has a default, so `ConfigTest()` needs no arguments. It is
    used in the examples below to check JSON save/load round-trips.
    """
    a: int = 1
    b: str = 'b'
    c: float = 3.14

# Round-trip a config through JSON twice: once with a bare file name and once
# with a path under `tmp/`. Each file is deleted right after the check.
for json_path in ('test.json', 'tmp/test.json'):
    conf = ConfigTest()
    conf.save(json_path)
    conf2 = ConfigTest.load_from_json(json_path)
    assert conf == conf2
    os.remove(json_path)

# `save` must reject a path without the `.json` suffix, and loading the
# `test.json` file deleted above must fail.
test_fail(lambda: conf.save('test'), contains="Path must end with `.json`,")
test_fail(lambda: ConfigTest.load_from_json('test.json'), contains="File not found")

relax.base.BaseModule

[source]

class relax.base.BaseModule (config, name=None)

Base class for all modules.

class TestModule(BaseModule):
    """Toy module whose only persisted state is its config, stored as
    `config.json` inside a directory."""

    def save(self, path):
        """Write `self.config` to `<path>/config.json`."""
        target = Path(path).joinpath('config.json')
        self.config.save(target)

    def load_from_path(self, path):
        """Replace `self.config` with the `ConfigTest` read from `<path>/config.json`."""
        source = Path(path).joinpath('config.json')
        self.config = ConfigTest.load_from_json(source)

# Save a module's config to disk, reload it into the same module, and check
# the round-trip. The in-memory config is cleared before reloading. Otherwise
# `module.config` would already equal `conf`, and the final assertion would
# pass even if `load_from_path` did nothing.
conf = ConfigTest()
module = TestModule(conf)
assert module.name == 'TestModule'  # with no `name` given, the module is named after its class
module.save('tmp/module/')
module.config = None
module.load_from_path('tmp/module/')
assert module.config == conf
shutil.rmtree('tmp/module/')

relax.base.PredFnMixedin

[source]

class relax.base.PredFnMixedin ()

Mixin class for modules that provide a `pred_fn` method.

Methods

[source]

pred_fn (x)

Return the model's prediction (or predicted probability) for input `x`.

relax.base.TrainableMixedin

[source]

class relax.base.TrainableMixedin ()

Mixin class for trainable modules.

Methods

[source]

is_trained ()

Return whether the module has been trained.

[source]

train (data, **kwargs)

Train the module.