lacss.train
Class | Description |
---|---|
lacss.train.Core | No JIT compilation of the model, i.e., for debugging |
lacss.train.JIT | JIT-compile the model. The default strategy |
lacss.train.VMapped | Transform the model with vmap. This allows defining a model on unbatched data but training it with batched data. |
lacss.train.Distributed | Transform the model with pmap. This allows training the model on multiple devices. |
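The strategy is normally set when constructing a Trainer and can be overridden per call in train() or test(). A minimal sketch, assuming placeholder names (my_module, my_loss_func, my_dataset) for the model, loss, and data:
```
import lacss.train

# Trainer is a dataclass; strategy selects the training backend (default: JIT)
trainer = lacss.train.Trainer(
    model=my_module,               # a Flax module
    losses=my_loss_func,           # loss_fn(batch, prediction) -> float
    strategy=lacss.train.VMapped,  # model defined on unbatched data
)

# override the strategy for a single run, e.g. to debug without JIT
debug_it = trainer.train(my_dataset, strategy=lacss.train.Core)
```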
lacss.train.TFDatasetAdapter
Convert a tf.data.Dataset into a python iterable suitable for lacss.train.Trainer:
my_dataset = TFDatasetAdapter(my_tf_dataset)
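A slightly fuller sketch, assuming a tf.data.Dataset named my_tf_dataset that yields (inputs, labels) tuples; the prefetch setting is just a placeholder:
```
import tensorflow as tf
from lacss.train import TFDatasetAdapter

# wrap the tf.data pipeline as a plain python iterable
my_dataset = TFDatasetAdapter(my_tf_dataset.repeat().prefetch(tf.data.AUTOTUNE))

# feed it to a Trainer as usual
train_it = trainer.train(my_dataset)
```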
lacss.train.TorchDataLoaderAdapter
Convert a torch DataLoader into a python iterable suitable for lacss.train.Trainer:
my_dataset = TorchDataLoaderAdapter(my_torch_dataloader)
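Similarly for PyTorch, a sketch assuming an existing torch Dataset named my_torch_dataset; the DataLoader settings are placeholders:
```
from torch.utils.data import DataLoader
from lacss.train import TorchDataLoaderAdapter

my_torch_dataloader = DataLoader(my_torch_dataset, batch_size=4, shuffle=True)

# wrap the dataloader as a plain python iterable for the Trainer
my_dataset = TorchDataLoaderAdapter(my_torch_dataloader)
```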
lacss.train.Trainer
dataclass
A general-purpose FLAX model trainer. Helps avoid most of the boilerplate code when training with FLAX.
Attributes:
Name | Type | Description |
---|---|---|
model | Module | A Flax module |
losses | LOSSES | A collection of loss functions (loss_fn(batch: Any, prediction: Any) -> float). |
optimizer | Optimizer | An optax optimizer |
seed | int \| RNG | RNG seed |
strategy | type | A training strategy type |
Example:
```
trainer = lacss.train.Trainer(my_module, my_loss_func)
train_it = trainer.train(my_dataset)
for k in range(train_steps):
    _ = next(train_it)
    if k % 1000 == 0:
        print(train_it.loss_logs)
        train_it.reset_loss_logs()
```
compute_metrics(*args, **kwargs)
A convenience function to compute all metrics. See the test() function.
Returns:
Type | Description |
---|---|
dict | A metric dict. Keys are metric names. |
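A sketch, assuming compute_metrics() accepts the same arguments as test() (dataset, metrics, and the model variables); val_dataset, my_metric, and my_variables are placeholders:
```
# my_variables would typically come from the TrainIterator (see test() below)
metric_values = trainer.compute_metrics(val_dataset, [my_metric], my_variables)
print(metric_values)  # a dict keyed by metric name
```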
test(dataset, metrics, variables, strategy=None, **kwargs)
Create a test/validation iterator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Iterable | An iterator or iterable to supply the testing data. It should yield tuples of (inputs, labels). | required |
metrics | METRICS | A list of Metric objects. Each should implement two functions: m.update(batch, prediction), where prediction is the model output and batch is a tuple of (x, y_true), and m.compute(), which returns the accumulated metric value. | required |
variables | dict | Model weights etc., typically obtained from a TrainIterator. | required |
strategy | type \| None | Optionally override the default strategy. | None |
Returns:
Type | Description |
---|---|
Iterator | An iterator. Stepping through it drives the update of each metric object. The iterator itself returns the list of metrics. |
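A minimal sketch of a validation pass, with placeholder names for the dataset, the metric object, and the variables obtained from a training iterator:
```
# create the test iterator; stepping through it updates each metric object
test_it = trainer.test(val_dataset, [my_metric], my_variables)

for metrics in test_it:
    pass  # each step consumes one batch and updates the metrics

# accumulated values after the full pass
print([m.compute() for m in metrics])
```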
train(dataset, strategy=None, rng_cols=[], init_vars=None, frozen=None, **kwargs)
Create the training iterator
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Iterable | An iterator or iterable to supply the training data. | required |
strategy | type \| None | Optionally override the default strategy. | None |
rng_cols | Sequence[str] | Names of any RNGs used by the model. Should be a list of strings. | [] |
init_vars | dict \| None | Optional variables to initialize the model. | None |
frozen | dict \| None | A pytree indicating frozen parameters. | None |
**kwargs | | Additional keyword args passed to the model, e.g. training=True. | {} |
Returns:
Type | Description |
---|---|
TrainIterator | A TrainIterator. Stepping through the iterator will train the model. |
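A sketch showing the optional arguments, with placeholder names; the "dropout" RNG name and the frozen pytree layout are assumptions about a specific model, not defaults:
```
train_it = trainer.train(
    my_dataset,
    rng_cols=["dropout"],       # RNGs the model requests by name
    init_vars=None,             # or pre-initialized variables
    frozen={"backbone": True},  # hypothetical pytree marking frozen params
    training=True,              # extra kwargs are passed to the model
)

for _ in range(1000):
    _ = next(train_it)          # each step updates the model
```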
lacss.train.LacssTrainer
Main trainer class for Lacss
__init__(config={}, collaborator_config=None, *, optimizer=None, seed=42, strategy=JIT)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | dict | Configuration dictionary for the Lacss model. | {} |
collaborator_config | Optional[dict] | Configuration dictionary for the collaborator model used in weakly-supervised training. If set to None, no collaborator model will be created, and training with weak supervision will result in an error. | None |
Other Parameters:
Name | Type | Description |
---|---|---|
seed | Union[int, Array] | RNG seed |
optimizer | Optional[Optimizer] | Override the default optimizer |
strategy | type | Training backend. See the table of training backends at the top of this page. |
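A construction sketch; config={} presumably leaves the Lacss model at its default configuration, and the optimizer shown is only an example:
```
import optax
from lacss.train import LacssTrainer, VMapped

trainer = LacssTrainer(
    config={},                    # default Lacss model configuration
    collaborator_config=None,     # no collaborator model (no weak supervision)
    optimizer=optax.adamw(1e-4),  # override the default optimizer
    seed=42,
    strategy=VMapped,
)
```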
do_training(dataset, val_dataset=None, n_steps=50000, validation_interval=5000, checkpoint_manager=None, *, warmup_steps=0, sigma=20.0, pi=2.0, init_vars=None)
Run the training.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Iterable | A data iterator feeding the training data. The data should be a tuple (x, y). x is a dictionary with at least two keys: "image", the training image, and "gt_locations", the point labels as an Nx2 array. y is a dictionary of extra labels and can be None for point-supervised training; it may contain "gt_labels", an index-label image (H, W) serving as segmentation label, and "gt_image_mask", a binary image (H, W) used for weakly supervised training. | required |
n_steps | int | Total training steps | 50000 |
validation_interval | int | Step interval at which to perform validation and checkpointing. | 5000 |
val_dataset | Iterable \| None | If not None, perform validation on this dataset. The data should be a tuple (x, y): x is a dictionary with one key, "image"; y is a dictionary with two labels, "gt_bboxes" and "gt_locations". | None |
checkpoint_manager | Optional[CheckpointManager] | If supplied, will be used to create checkpoints (see the sketch after this table for one way to obtain a checkpoint manager). | None |
Other Parameters:
Name | Type | Description |
---|---|---|
warmup_steps | int | Only used for point-supervised training. Pre-training steps, during which a large sigma value is used. Should be a multiple of validation_interval. |
sigma | float | Only for point-supervised training. Expected cell size. |
pi | float | Only for point-supervised training. Amplitude of the prior term. |
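A sketch of the expected training data format and a call to do_training(); all arrays and sizes below are random placeholders:
```
import numpy as np

def train_data_gen():
    while True:
        x = {
            "image": np.zeros([512, 512, 3], "float32"),   # training image
            "gt_locations": np.zeros([10, 2], "float32"),  # Nx2 point labels
        }
        y = None  # point-supervised; otherwise supply "gt_labels" and/or "gt_image_mask"
        yield x, y

trainer.do_training(
    train_data_gen(),
    val_dataset=None,
    n_steps=50000,
    validation_interval=5000,
)
```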
save(save_path)
Save a pickled copy of the Lacss model in the form of (module:Lacss, weights:FrozenDict). Only saves the principal model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
save_path | | Path to the .pkl file | required |
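A usage sketch; the file name is a placeholder, and since the file is a plain pickle of a (module, weights) tuple it should be loadable with the standard pickle module:
```
trainer.save("lacss_model.pkl")

import pickle
with open("lacss_model.pkl", "rb") as f:
    module, weights = pickle.load(f)   # (Lacss module, FrozenDict of weights)
```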