Training backends

| Class | Description |
|---|---|
| `xtrain.Core` | No JIT compilation of the model; useful for debugging. |
| `xtrain.JIT` | JIT-compiles the model. This is the default strategy. |
| `xtrain.Distributed` | Transforms the model with `pmap`, allowing the model to be trained on multiple devices. |

xtrain.Trainer dataclass

A general-purpose Flax model trainer. It helps avoid most of the boilerplate code when training with Flax.

Attributes:

| Name | Type | Description |
|---|---|---|
| `model` | `Module` | A Flax module. |
| `losses` | `LOSSES` | A collection of loss functions (`loss_fn(batch: Any, prediction: Any) -> float`). |
| `optimizer` | `Optimizer` | An optax optimizer. |
| `seed` | `int \| RNG` | RNG seed. |
| `strategy` | `type` | A training strategy type. |
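The `loss_fn(batch, prediction) -> float` signature can be illustrated with a minimal, hypothetical mean-squared-error loss written in plain Python. It assumes that `batch` is the tuple produced by the dataset with the labels at index 1, and that `prediction` is a flat sequence of numbers; a real Flax loss would typically operate on JAX arrays instead. The name `mse_loss` is invented for this sketch.

```python
from typing import Any, Sequence


def mse_loss(batch: Any, prediction: Sequence[float]) -> float:
    """Hypothetical loss with the loss_fn(batch, prediction) -> float signature."""
    # Assumption: batch is (inputs, labels[, sample_weight]); labels sit at index 1.
    labels = batch[1]
    if len(labels) != len(prediction):
        raise ValueError("labels and prediction must have the same length")
    # Mean of the squared differences between predictions and labels.
    return sum((p - y) ** 2 for p, y in zip(prediction, labels)) / len(labels)


print(mse_loss(([0.0, 0.0], [1.0, 2.0]), [1.0, 4.0]))  # 2.0
```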

Example:

```
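import lacss.train

# my_module: a Flax module; my_loss_func: loss_fn(batch, prediction) -> float;
# my_dataset: an iterable of (inputs, labels, sample_weight) tuples;
# train_steps: the total number of training steps.
trainer = lacss.train.Trainer(my_module, my_loss_func)

# Create the training iterator; each next() call performs one training step.
train_it = trainer.train(my_dataset)

for k in range(train_steps):
    _ = next(train_it)
    if k % 1000 == 0:
        # Print the logged losses, then reset the logs.
        print(train_it.loss_logs)
        train_it.reset_loss_logs()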
```

compute_metrics(*args, **kwargs)

A convenience function to compute all metrics. See the `test()` function.

Returns:

| Type | Description |
|---|---|
| `dict` | A metric dict. Keys are metric names. |

predict(dataset, variables, strategy=None, method=None, **kwargs)

Create a predictor iterator.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dataset` | `Iterable` | An iterator or iterable to supply the input data. | required |
| `variables` | `dict` | Model weights etc., typically obtained from the `TrainIterator`. | required |
| `strategy` | `type \| None` | Optionally override the default backend. | `None` |

Returns:

| Type | Description |
|---|---|
| | An iterator. Stepping through it produces model predictions. |

test(dataset, metrics, variables, strategy=None, method=None, **kwargs)

Create a test/validation iterator.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dataset` | `Iterable` | An iterator or iterable to supply the testing data. The iterator should yield a tuple of `(inputs, labels)`. | required |
| `metrics` | `METRICS` | A list of Metric objects. Each should implement two methods: `m.update(preds, **kwargs)`, where `preds` is the model output and the remaining kwargs are the contents of the labels; and `m.compute()`, which should return the accumulated metric value. | required |
| `variables` | `dict` | Model weights etc., typically obtained from the `TrainIterator`. | required |
| `strategy` | `type \| None` | Optionally override the default strategy. | `None` |

Returns:

| Type | Description |
|---|---|
| `Iterator` | An iterator. Stepping through it drives the update of each metric object. The iterator itself returns the list of metrics. |
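The metric protocol described above can be sketched with a minimal, hypothetical `Accuracy` class. It assumes the labels reach `update()` as a keyword argument named `labels` (the actual keyword names depend on the contents of your labels) and that predictions are plain sequences of class indices. The loop at the end only imitates, in simplified form, what stepping through the test iterator is described as doing; it is not the library's implementation.

```python
from typing import Sequence


class Accuracy:
    """Hypothetical metric implementing update(preds, **kwargs) and compute()."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: Sequence[int], **kwargs) -> None:
        # Assumption: the labels arrive as the keyword argument "labels".
        labels = kwargs["labels"]
        self.correct += sum(int(p == y) for p, y in zip(preds, labels))
        self.total += len(labels)

    def compute(self) -> float:
        # Accumulated accuracy over all batches seen so far.
        return self.correct / self.total if self.total else 0.0


# Simplified stand-in for the test loop: predictions are given directly
# instead of being computed by a model, and each batch updates every metric.
batches = [
    ([1, 0, 1], {"labels": [1, 1, 1]}),
    ([0, 0], {"labels": [0, 1]}),
]
metrics = [Accuracy()]
for preds, label_kwargs in batches:
    for m in metrics:
        m.update(preds, **label_kwargs)

print(metrics[0].compute())  # 0.6
```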

train(dataset, *, strategy=None, rng_cols=['dropout'], init_vars=None, frozen=None, method=None, **kwargs)

Create the training iterator.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dataset` | `Iterable` | An iterator or iterable to supply the training data. The dataset should produce `(inputs, labels, sample_weight)`; both `labels` and `sample_weight` are optional. The inputs are either a list (not a tuple) or a dict. In the latter case, the keys are interpreted as keyword-argument names for the model's call function. | required |
| `strategy` | `type \| None` | Optionally override the default strategy. | `None` |
| `rng_cols` | `Sequence[str]` | Names of any RNGs used by the model, given as a list of strings. | `['dropout']` |
| `init_vars` | `dict \| None` | Optional variables to initialize the model. | `None` |
| `frozen` | `dict \| None` | A bool pytree (matching the model parameter tree) indicating frozen parameters. | `None` |
| `**kwargs` | | Additional keyword args passed to the model, e.g. `training=True`. | `{}` |

Returns:

| Type | Description |
|---|---|
| `TrainIterator` | A `TrainIterator`. Stepping through the iterator trains the model. |
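To make the expected `dataset` and `frozen` arguments concrete, the sketch below shows a plain Python generator yielding `(inputs, labels, sample_weight)` tuples with dict inputs, together with a bool pytree mirroring a parameter tree. The generator name `my_data_gen`, the keyword `image`, the label layout, and the parameter names (`Dense_0`, `Dense_1`, `kernel`, `bias`) are hypothetical placeholders; in practice the dict keys must match your model's call arguments, and the `frozen` tree must match your model's actual parameter tree.

```python
import random
from typing import Any, Dict, Iterator, Tuple


def my_data_gen(n: int, seed: int = 0) -> Iterator[Tuple[Dict[str, Any], float, float]]:
    """Hypothetical training data in the (inputs, labels, sample_weight) form."""
    rng = random.Random(seed)
    for _ in range(n):
        x = [rng.random() for _ in range(4)]
        # Dict inputs: the key "image" would reach the model's call function as image=...
        inputs = {"image": x}
        labels = sum(x)  # illustrative regression target
        sample_weight = 1.0  # optional per-sample weight
        yield inputs, labels, sample_weight


# Hypothetical bool pytree assumed to mirror the model's parameter tree:
# True marks a parameter as frozen (not updated during training).
frozen = {
    "Dense_0": {"kernel": True, "bias": True},
    "Dense_1": {"kernel": False, "bias": False},
}

inputs, labels, sample_weight = next(my_data_gen(1))
print(sorted(inputs))  # ['image']
```

With a `trainer` constructed as in the example above, a call along the lines of `trainer.train(my_data_gen(10_000), frozen=frozen)` would then be plausible, provided `frozen` really matches the model's parameters.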

xtrain.TFDatasetAdapter

Convert a `tf.data.Dataset` into a Python iterable suitable for `lacss.train.Trainer`.

my_dataset = TFDatasetAdapter(my_tf_dataset)

xtrain.TorchDataLoaderAdapter

Convert a PyTorch `DataLoader` into a Python iterable suitable for `lacss.Trainer`.

my_dataset = TorchDataLoaderAdapter(my_torch_dataloader)

xtrain.GeneratorAdapter

Convert a Python generator function into a Python iterable suitable for `lacss.train.Trainer`, with an option to prefetch data.
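The idea behind prefetching can be sketched in plain Python: a background thread runs the generator function and fills a bounded queue, so the next item is usually ready before the training loop asks for it. This is a hypothetical illustration (`PrefetchIterable` and its `prefetch` parameter are names invented here), not the actual `GeneratorAdapter` implementation or signature.

```python
import queue
import threading
from typing import Any, Callable, Iterator, List

_END = object()  # sentinel marking the end of the generator's output


class PrefetchIterable:
    """Hypothetical iterable that runs a generator function in a background thread."""

    def __init__(self, gen_fn: Callable[[], Iterator[Any]], prefetch: int = 2) -> None:
        self.gen_fn = gen_fn      # called anew for every iteration pass
        self.prefetch = prefetch  # items buffered ahead; <= 0 makes the queue unbounded

    def __iter__(self) -> Iterator[Any]:
        buf: "queue.Queue[Any]" = queue.Queue(maxsize=self.prefetch)
        errors: List[BaseException] = []

        def worker() -> None:
            try:
                for item in self.gen_fn():
                    buf.put(item)  # blocks while the buffer is full
            except Exception as e:  # hand the error over to the consumer
                errors.append(e)
            finally:
                buf.put(_END)

        # Daemon thread: if the consumer stops early, the blocked worker
        # does not keep the process alive (acceptable for a sketch).
        threading.Thread(target=worker, daemon=True).start()
        while True:
            item = buf.get()
            if item is _END:
                if errors:
                    raise errors[0]
                return
            yield item


def my_gen():
    for i in range(5):
        yield {"x": i}


data = PrefetchIterable(my_gen, prefetch=2)
print([b["x"] for b in data])  # [0, 1, 2, 3, 4]
```

Because a single producer feeds a FIFO queue, item order is preserved, and since `__iter__` starts a fresh generator each time, the wrapper can be iterated more than once.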