lacss.train

Training backends (strategy types):

| Class | Description |
| --- | --- |
| lacss.train.Core | No model JIT compilation, i.e., for debugging |
| lacss.train.JIT | JIT-compile the model. The default strategy |
| lacss.train.VMapped | Transform the model with vmap. This allows defining the model on unbatched data but training with batched data. |
| lacss.train.Distributed | Transform the model with pmap. This allows training the model on multiple devices. |
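For example, a non-default backend can be selected via the Trainer's strategy attribute. A minimal sketch; `my_flax_module` and `my_loss_fn` are placeholders:

```
import optax
import lacss.train

# Sketch: pick a non-default backend via the strategy argument.
# VMapped lets the model be written for unbatched inputs while
# training on batched data; Core disables JIT for debugging.
trainer = lacss.train.Trainer(
    model=my_flax_module,        # placeholder: any Flax module
    losses=my_loss_fn,           # placeholder: loss_fn(batch, prediction) -> float
    optimizer=optax.adam(1e-4),
    strategy=lacss.train.VMapped,
)
```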

lacss.train.TFDatasetAdapter

Convert a tf.data.Dataset into a Python iterable suitable for lacss.train.Trainer.

my_dataset = TFDatasetAdapter(my_tf_dataset)

lacss.train.TorchDataLoaderAdapter

Convert a torch DataLoader into a Python iterable suitable for lacss.train.Trainer.

my_dataset = TorchDataLoaderAdapter(my_torch_dataloader)
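A minimal sketch combining the two adapters with Trainer.train; `my_torch_dataloader`, `my_tf_dataset`, and `trainer` are placeholders:

```
from lacss.train import TFDatasetAdapter, TorchDataLoaderAdapter

# Wrap framework-specific data sources so the Trainer can iterate over them.
train_data = TorchDataLoaderAdapter(my_torch_dataloader)  # placeholder DataLoader
val_data = TFDatasetAdapter(my_tf_dataset)                # placeholder tf.data.Dataset

train_it = trainer.train(train_data)                      # trainer: a lacss.train.Trainer
```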

lacss.train.Trainer dataclass

A general-purpose FLAX model trainer. Helps avoid most of the boilerplate code when training with FLAX.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| model | Module | A Flax module |
| losses | LOSSES | A collection of loss functions (`loss_fn(batch: Any, prediction: Any) -> float`) |
| optimizer | Optimizer | An optax optimizer |
| seed | int \| RNG | RNG seed |
| strategy | type | A training strategy type |

Example:

```
trainer = lacss.train.Trainer(my_module, my_loss_func)

train_it = trainer.train(my_dataset)

for k in range(train_steps):
    _ = next(train_it)
    if k % 1000 == 0:
        print(train_it.loss_logs)
        train_it.reset_loss_logs()
```

compute_metrics(*args, **kwargs)

A convenience function to compute all metrics. See the test() function.

Returns:

| Type | Description |
| --- | --- |
| dict | A metric dict. Keys are metric names. |
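A minimal sketch of computing metrics on a validation set; it takes the same arguments as test() below, and `val_dataset`, `my_metric`, and `variables` (model weights obtained from the TrainIterator) are placeholders:

```
# Sketch: compute_metrics forwards its arguments to test() and
# accumulates every metric in one pass over the dataset.
results = trainer.compute_metrics(
    val_dataset,   # placeholder: yields (inputs, labels)
    [my_metric],   # placeholder: Metric objects with update()/compute()
    variables,     # placeholder: model weights from the TrainIterator
)
print(results)     # dict keyed by metric name
```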

test(dataset, metrics, variables, strategy=None, **kwargs)

Create test/validation iterator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Iterable | An iterator or iterable supplying the testing data. The iterator should yield tuples of (inputs, labels). | required |
| metrics | METRICS | A list of Metric objects. Each should implement two functions: m.update(batch, prediction), where prediction is the model output and batch is a tuple of (x, y_true); and m.compute(), which should return the accumulated metric value. | required |
| variables | dict | Model weights etc., typically obtained from the TrainIterator | required |
| strategy | type \| None | Optionally override the default strategy. | None |

Returns:

| Type | Description |
| --- | --- |
| Iterator | An iterator. Stepping through it drives the update of each metric object. The iterator itself returns the list of metrics. |
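A minimal sketch of a validation pass; `val_dataset`, `my_metric`, and `variables` are placeholders as above:

```
# Sketch: stepping through the test iterator drives the metric updates;
# the accumulated value is read back from the metric object afterwards.
test_it = trainer.test(val_dataset, [my_metric], variables)
for metrics in test_it:
    pass                        # each step updates the metric objects
print(my_metric.compute())      # accumulated metric value
```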

train(dataset, strategy=None, rng_cols=[], init_vars=None, frozen=None, **kwargs)

Create the training iterator

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Iterable | An iterator or iterable supplying the training data. The dataset should produce (inputs, labels, sample_weight), although both labels and sample_weight are optional. The inputs are either a tuple or a dict; if a dict, the keys are interpreted as the names of keyword args of the model's call function. | required |
| strategy | type \| None | Optionally override the default strategy. | None |
| rng_cols | Sequence[str] | Names of any RNGs used by the model. Should be a list of strings. | [] |
| init_vars | dict \| None | Optional variables to initialize the model | None |
| frozen | dict \| None | A pytree indicating frozen parameters | None |
| **kwargs | | Additional keyword args passed to the model, e.g. training=True | {} |

Returns:

| Type | Description |
| --- | --- |
| TrainIterator | A TrainIterator. Stepping through the iterator will train the model. |
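A minimal sketch of starting a training run with dict-style inputs; the "dropout" RNG name and the training=True flag are illustrative assumptions, not requirements of any particular model:

```
# Sketch: dict inputs are forwarded to the model as keyword arguments;
# rng_cols names the RNG streams the model draws from.
train_it = trainer.train(
    my_dataset,             # placeholder: yields (inputs, labels[, sample_weight])
    rng_cols=["dropout"],   # assumed RNG name, model-dependent
    training=True,          # extra kwarg forwarded to the model
)
for step in range(10000):
    _ = next(train_it)      # each next() performs one optimization step
```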

lacss.train.LacssTrainer

Main trainer class for Lacss

__init__(config={}, collaborator_config=None, *, optimizer=None, seed=42, strategy=JIT)

Constructor

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | dict | Configuration dictionary for the Lacss model | {} |
| collaborator_config | Optional[dict] | Configuration dictionary for the collaborator model used in weakly-supervised training. If set to None, no collaborator model will be created; in that case, training with weak supervision will result in an error. | None |

Other Parameters:

| Name | Type | Description |
| --- | --- | --- |
| seed | Union[int, Array] | RNG seed |
| optimizer | Optional[Optimizer] | Override the default optimizer |
| strategy | type | Training backend. See Training backends above. |
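A minimal sketch of constructing the trainer; the config contents are model-specific and left empty here, and the optax optimizer is an illustrative assumption:

```
import optax
from lacss.train import LacssTrainer, JIT

# Sketch: an empty config uses the model defaults; passing
# collaborator_config=None disables weakly-supervised training.
trainer = LacssTrainer(
    config={},
    collaborator_config=None,
    optimizer=optax.adamw(1e-4),   # assumed override; omit to keep the default
    seed=42,
    strategy=JIT,
)
```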

do_training(dataset, val_dataset=None, n_steps=50000, validation_interval=5000, checkpoint_manager=None, *, warmup_steps=0, sigma=20.0, pi=2.0, init_vars=None)

Run the training loop.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Iterable | A data iterator feeding training data. The data should be a tuple (x, y). x is a dictionary with at least two keys: "image", the training image, and "gt_locations", the point labels as an Nx2 array. y is a dictionary of extra labels; it may be None for point-supervised training, or contain "gt_labels", an index-label image (H, W) used as the segmentation label, and/or "gt_image_mask", a binary image (H, W) used for weakly-supervised training. | required |
| n_steps | int | Total training steps | 50000 |
| validation_interval | int | Step interval at which to perform validation and checkpointing. | 5000 |
| val_dataset | Iterable \| None | If not None, perform validation on this dataset. The data should be a tuple (x, y): x is a dictionary with one key, "image"; y is a dictionary with two labels, "gt_bboxes" and "gt_locations". | None |
| checkpoint_manager | Optional[CheckpointManager] | If supplied, will be used to create checkpoints. A checkpoint manager can be obtained as shown below. | None |

```
options = orbax.checkpoint.CheckpointManagerOptions(...)
manager = orbax.checkpoint.CheckpointManager(
    'path/to/directory/',
    options=options,
)
```

Other Parameters:

| Name | Type | Description |
| --- | --- | --- |
| warmup_steps | int | Only used for point-supervised training. Pretraining steps, for which a large sigma value is used. This should be a multiple of validation_interval. |
| sigma | float | Only for point-supervised training. Expected cell size |
| pi | float | Only for point-supervised training. Amplitude of the prior term. |
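A minimal sketch of a point-supervised run following the data format described above; `image`, `locations`, and `my_val_data` are placeholders:

```
# Sketch of a point-supervised run: x holds the image and the Nx2 point
# labels, and y is None because no extra labels are available.
def my_data_gen():
    while True:
        x = {"image": image, "gt_locations": locations}  # placeholders
        yield x, None

trainer.do_training(
    my_data_gen(),
    val_dataset=my_val_data,       # placeholder, or None to skip validation
    n_steps=50000,
    validation_interval=5000,
    sigma=20.0,                    # expected cell size
    pi=2.0,                        # amplitude of the prior term
)
```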

save(save_path)

Save a pickled copy of the Lacss model in the form of (module:Lacss, weights:FrozenDict). Only saves the principal model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| save_path | | Path to the pkl file | required |
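A minimal sketch of saving and reloading the principal model; reading the file back with the standard pickle module is an assumption based on the (module, weights) description above:

```
import pickle

trainer.save("lacss_model.pkl")            # write the pickled model to disk

# Assumption: the file is a plain pickle of the (module, weights) pair.
with open("lacss_model.pkl", "rb") as f:
    module, weights = pickle.load(f)
```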