Training backends
class | description |
---|---|
xtrain.Core | No JIT compilation of the model, i.e., for debugging |
xtrain.JIT | JIT-compile the model. The default strategy |
xtrain.Distributed | Transform the model with pmap. This allows training the model on multiple devices. |
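A minimal sketch of selecting a backend, assuming the strategy class is passed as the Trainer's strategy attribute (documented below); my_module and my_loss_func are placeholders defined elsewhere:

```
import xtrain

# Hedged sketch: select a backend by passing one of the strategy classes.
# xtrain.JIT is the default; xtrain.Core disables JIT compilation for debugging.
trainer = xtrain.Trainer(
    model=my_module,        # a Flax module (placeholder)
    losses=my_loss_func,    # loss_fn(batch, prediction) -> float
    strategy=xtrain.Core,   # or xtrain.JIT / xtrain.Distributed
)
```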
xtrain.Trainer
dataclass
A general-purpose FLAX model trainer. Helps avoid most of the boilerplate code when training with FLAX.
Attributes:
Name | Type | Description |
---|---|---|
model | Module | A Flax module |
losses | LOSSES | A collection of loss functions ( `loss_fn(batch: Any, prediction: Any) -> float` ). |
optimizer | Optimizer | An optax optimizer |
seed | int \| RNG | RNG seed |
strategy | type | A training strategy type |
Example:
```
trainer = lacss.train.Trainer(my_module, my_loss_func)
train_it = trainer.train(my_dataset)
for k in range(train_steps):
    _ = next(train_it)
    if k % 1000 == 0:
        print(train_it.loss_logs)
        train_it.reset_loss_logs()
```
compute_metrics(*args, **kwargs)
A convenient function to compute all metrics. See the test() function.
Returns:
Type | Description |
---|---|
dict | A metric dict. Keys are metric names. |
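A usage sketch, assuming the arguments mirror those of test() below; trainer, val_dataset, my_metrics, and variables are placeholders:

```
# Hypothetical sketch: compute_metrics is assumed to take the same arguments
# as test() and return the accumulated metric dict in a single call.
metric_dict = trainer.compute_metrics(val_dataset, my_metrics, variables)
for name, value in metric_dict.items():
    print(f"{name}: {value}")
```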
predict(dataset, variables, strategy=None, method=None, **kwargs)
Create a predictor iterator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Iterable | An iterator or iterable to supply the input data. | required |
variables | dict | Model weights etc., typically obtained from a TrainIterator. | required |
strategy | type \| None | Optionally override the default backend. | None |
Returns:
Type | Description |
---|---|
Iterator | An iterator. Stepping through it will produce model predictions. |
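A minimal sketch of consuming the prediction iterator; my_dataset and variables are placeholders, with variables typically taken from a training iterator:

```
# Hedged sketch: create a predictor iterator and step through the predictions.
pred_it = trainer.predict(my_dataset, variables)
for prediction in pred_it:
    ...  # post-process each model output here
```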
test(dataset, metrics, variables, strategy=None, method=None, **kwargs)
Create a test/validation iterator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Iterable | An iterator or iterable to supply the testing data. The iterator should yield a tuple of (inputs, labels). | required |
metrics | METRICS | A list of Metric objects. Each should have two functions: `m.update(preds, **kwargs)`, where preds is the model output and the remaining kwargs are the contents of labels; and `m.compute()`, which should return the accumulated metric value. | required |
variables | dict | Model weights etc., typically obtained from a TrainIterator. | required |
strategy | type \| None | Optionally override the default strategy. | None |
Returns:
Type | Description |
---|---|
Iterator | An iterator. Stepping through it will drive the updating of each metric object. The iterator itself returns the list of metrics. |
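A minimal sketch of driving the metric updates, assuming each step yields the current list of metric objects as described above; val_dataset, my_metrics, and variables are placeholders:

```
# Hedged sketch: stepping through the iterator updates the metric objects;
# call m.compute() afterwards to read the accumulated values.
test_it = trainer.test(val_dataset, my_metrics, variables)
for metrics in test_it:
    pass  # each step runs the model on one batch and updates the metrics
results = [m.compute() for m in metrics]
```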
train(dataset, *, strategy=None, rng_cols=['dropout'], init_vars=None, frozen=None, method=None, **kwargs)
Create the training iterator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Iterable | An iterator or iterable to supply the training data. The dataset should produce | required |
strategy | type \| None | Optionally override the default strategy. | None |
rng_cols | Sequence[str] | Names of any RNG used by the model. Should be a list of strings. | ['dropout'] |
init_vars | dict \| None | Optional variables to initialize the model. | None |
frozen | dict \| None | A bool pytree (matching the model parameter tree) indicating frozen parameters. | None |
`**kwargs` | | Additional keyword args passed to the model, e.g. "training=True". | {} |
Returns:
Type | Description |
---|---|
TrainIterator | A TrainIterator. Stepping through the iterator will train the model. |
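A sketch of a training run with some of the optional arguments; my_dataset is a placeholder and my_frozen_tree stands in for a bool pytree matching the model's parameter tree:

```
# Hedged sketch: training with optional arguments.
train_it = trainer.train(
    my_dataset,
    rng_cols=["dropout"],    # RNG streams used by the model
    frozen=my_frozen_tree,   # mark parameters that should not be updated
    training=True,           # extra kwargs are forwarded to the model call
)
_ = next(train_it)           # each step consumes one batch and updates the weights
```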
xtrain.TFDatasetAdapter
Convert a tf.data.Dataset into a python iterable suitable for lacss.train.Trainer.
```
my_dataset = TFDatasetAdapter(my_tf_dataset)
```
xtrain.TorchDataLoaderAdapter
Convert a torch DataLoader into a python iterable suitable for lacss.train.Trainer.
```
my_dataset = TorchDataLoaderAdapter(my_torch_dataloader)
```
xtrain.GeneratorAdapter
Convert a python generator function into a python iterable suitable for lacss.train.Trainer, with an option to prefetch data.
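A hedged sketch of wrapping a generator function; whether a prefetch option exists as a constructor argument, and under what name, is not confirmed here, so only the generator function is passed:

```
import xtrain

def my_gen():
    # placeholder generator yielding (inputs, labels) batches
    for inputs, labels in my_batches:
        yield inputs, labels

my_dataset = xtrain.GeneratorAdapter(my_gen)
```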