module documentation

Undocumented

Function get_model: Get the model class from an identifier string.
Function get_summary: Print a summary of the model with identifier model_string. The input shape and batch size are needed to estimate the memory footprint of the forward and backward passes.
Function load_model: Undocumented
Function load_model_from_metafile: Undocumented
Function load_model_from_state_dict: Undocumented
Variable custom_models_dict: Undocumented
Variable models_with_channel_factor_dict: Undocumented
def get_model(model_string, **kwargs):

Get model class from identifier string:

    Saddler et al. 2020 dnn_modelling:
        model_string = "saddler-???" where ??? should match one of the following:
            TODO

    Kell et al. 2018 dnn_modelling:
        model_string = "kell"
            returns a model with the same architecture but with no input split and number of output classes
        model_string = "kell-???"
            returns the same as above, but with the number of channels n reduced to $n // ??$; ??=1 makes no changes
        model_string = "kell-1?-2?"
            same as above, with 1? as the channel reduction factor and 2? setting the size of the hidden fully connected layer
Parameters
model_string: model identifier string
**kwargs
Returns
(Any, bool): [0] the model class itself; [1] a boolean flag indicating whether the model is based on PyTorch
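The identifier grammar for the Kell et al. 2018 variants can be illustrated with a small parser. This is a minimal sketch, not the module's implementation. It assumes that the ??, 1? and 2? placeholders are integers separated by hyphens. The names KellSpec, parse_kell_string and reduced_channels are hypothetical helpers introduced only for illustration. Saddler identifiers are not covered, since their format is still marked TODO.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class KellSpec:
    """Hypothetical container for the settings encoded in a "kell..." identifier."""
    channel_factor: int = 1            # divisor applied to channel counts (1 = unchanged)
    hidden_size: Optional[int] = None  # None = keep the default hidden layer size


def parse_kell_string(model_string: str) -> KellSpec:
    """Parse "kell", "kell-<factor>" or "kell-<factor>-<hidden>".

    Hypothetical helper: assumes the placeholders are hyphen-separated integers.
    """
    parts = model_string.split("-")
    if parts[0] != "kell" or len(parts) > 3:
        raise ValueError(f"not a kell identifier: {model_string!r}")
    spec = KellSpec()
    if len(parts) >= 2:
        spec.channel_factor = int(parts[1])
        if spec.channel_factor < 1:
            raise ValueError("channel factor must be a positive integer")
    if len(parts) == 3:
        spec.hidden_size = int(parts[2])
    return spec


def reduced_channels(n: int, factor: int) -> int:
    """Channel count after reduction, n // factor; factor=1 leaves n unchanged."""
    return n // factor


spec = parse_kell_string("kell-2-512")
print(spec)  # KellSpec(channel_factor=2, hidden_size=512)
print(reduced_channels(96, spec.channel_factor))  # 48
```

Under this reading, "kell-1" is equivalent to "kell", because a factor of 1 leaves every channel count unchanged.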
def get_summary(model_string, input_shape, batch_size, **kwargs):

Print a summary of the model with identifier model_string. The input shape and batch size are needed to estimate the memory footprint of the forward and backward passes.

Parameters
model_string: string identifier of the model to summarise
input_shape: shape of the expected input matrix
batch_size: batch size to use when training the model
**kwargs: other keyword arguments passed on to get_model
Returns
the model that was summarised
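The memory footprint that get_summary reports can be approximated from the input shape, the batch size, the per-layer output shapes and the parameter count. The sketch below rests on assumptions about how such an estimate is typically computed. It assumes float32 values and counts activations twice to cover their gradients in the backward pass, similar in spirit to summary tools such as torchsummary. It is not taken from get_summary itself, and estimate_memory_mb and its arguments are hypothetical.

```python
from math import prod
from typing import Dict, Sequence, Tuple

BYTES_PER_VALUE = 4  # assumption: all tensors are float32
BYTES_PER_MB = 1024 ** 2


def estimate_memory_mb(
    input_shape: Tuple[int, ...],
    batch_size: int,
    layer_output_shapes: Sequence[Tuple[int, ...]],
    num_params: int,
) -> Dict[str, float]:
    """Hypothetical estimate of training memory, in MB, for one batch.

    Shapes exclude the batch dimension; batch_size is multiplied in here.
    """
    input_mb = prod(input_shape) * batch_size * BYTES_PER_VALUE / BYTES_PER_MB
    # Values produced by all layers for one batch during the forward pass.
    activations = batch_size * sum(prod(shape) for shape in layer_output_shapes)
    # Factor 2: each stored activation needs a gradient of the same size in the backward pass.
    forward_backward_mb = 2 * activations * BYTES_PER_VALUE / BYTES_PER_MB
    params_mb = num_params * BYTES_PER_VALUE / BYTES_PER_MB
    return {
        "input_mb": input_mb,
        "forward_backward_mb": forward_backward_mb,
        "params_mb": params_mb,
        "total_mb": input_mb + forward_backward_mb + params_mb,
    }


# Example: single-channel 64x64 input, batch of 8, one layer with a 16x32x32 output.
print(estimate_memory_mb((1, 64, 64), 8, [(16, 32, 32)], 262_144))
# {'input_mb': 0.125, 'forward_backward_mb': 1.0, 'params_mb': 1.0, 'total_mb': 2.125}
```

Because activations scale linearly with batch_size while parameters do not, doubling the batch size in this sketch doubles input_mb and forward_backward_mb but leaves params_mb unchanged.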
def load_model(saved_model_string, device=None, **kwargs):

Undocumented

def load_model_from_metafile(metafile, state_dict_name, device=None, **kwargs):

Undocumented

def load_model_from_state_dict(model_string, state_dict_path=None, device=None, **kwargs):

Undocumented

custom_models_dict

Undocumented

models_with_channel_factor_dict

Undocumented