This module contains a script to train a model.
Train a Model on Data Split
timm_or_fastai_arch(arch:str)
Check if arch is a fast.ai or timm architecture and return the appropriate functions.
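As an illustration only, the dispatch could be sketched as below. It assumes a hand-maintained set of fast.ai architecture names (FASTAI_ARCHS), and every name not in that set is treated as a timm model name. The builder functions are hypothetical stand-ins that return strings so the sketch runs without fast.ai or timm installed. In a real module they would construct learners.

```python
from typing import Callable, Tuple

# Hypothetical stand-ins for the real fast.ai / timm learner constructors.
def build_fastai_learner(arch: str) -> str:
    return f"fastai learner for {arch}"

def build_timm_learner(arch: str) -> str:
    return f"timm learner for {arch}"

# Assumed registry of architecture names handled natively by fast.ai.
FASTAI_ARCHS = frozenset({"resnet18", "resnet34", "resnet50", "xresnet50", "densenet121"})

def timm_or_fastai_arch(arch: str) -> Tuple[str, Callable[[str], str]]:
    """Return ("fastai" | "timm", builder) for the given architecture name."""
    # Check fast.ai first: names like "resnet18" exist in both libraries,
    # and this sketch assumes the fast.ai version should win.
    if arch in FASTAI_ARCHS:
        return "fastai", build_fastai_learner
    # Anything else is assumed to be a timm model name.
    return "timm", build_timm_learner

source, build = timm_or_fastai_arch("efficientnet_b3")
learner = build("efficientnet_b3")
```

This sketch resolves overlapping names by checking the fast.ai set first. That ordering is a design assumption, not something the original documentation states.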
train(data_path:Path, epochs:int=1, lr:Union[float, str]=0.0003, frz:int=1,
      pre:int=800, re:int=256, bs:int=200, fold:int=4, smooth:bool=False,
      arch:str='resnet18', dump:bool=False, log:bool=False, mixup:float=0.0,
      fp16:bool=False, dls:DataLoaders=None, save:bool=False, pseudo:Path=None)
Train a learner on the training CSV (with folds) at data_path.
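The sketch below shows one way a fold column can turn the training CSV into a train/validation split. Rows whose fold equals the fold argument are held out for validation, and the rest are used for training. The column names (image_id, label, fold) and the helper split_by_fold are hypothetical. The sketch uses only the standard library instead of whatever the module actually uses.

```python
import csv
import io
from typing import Dict, List, Tuple

Row = Dict[str, str]

def split_by_fold(rows: List[Row], fold: int, fold_col: str = "fold") -> Tuple[List[Row], List[Row]]:
    """Hold out rows whose fold column equals `fold`; train on the rest."""
    train_rows: List[Row] = []
    valid_rows: List[Row] = []
    for row in rows:
        # CSV values are strings, so convert before comparing with the int fold.
        (valid_rows if int(row[fold_col]) == fold else train_rows).append(row)
    return train_rows, valid_rows

# Hypothetical training CSV with a precomputed fold column.
sample = io.StringIO("image_id,label,fold\nimg0,0,0\nimg1,1,4\nimg2,2,4\nimg3,0,1\n")
rows = list(csv.DictReader(sample))
train_rows, valid_rows = split_by_fold(rows, fold=4)
```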
Train Using Cross-Validation
softmax_RocAuc(logits, labels)
Compute RocAuc, first taking the softmax of logits.
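A NumPy-only sketch of the idea follows. It assumes single-label class indices and macro-averaged one-vs-rest AUC. The original may instead rely on fast.ai's RocAuc metric or scikit-learn. AUC is computed here as the probability that a positive example outscores a negative one, with ties counting one half. The order of operations matters. Each softmax column is normalized against the other classes in the same row, so ranking examples by probability can differ from ranking them by raw logits.

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtract the row max first for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def binary_auc(scores: np.ndarray, is_pos: np.ndarray) -> float:
    """ROC AUC as P(positive score > negative score), ties counting 1/2."""
    pos, neg = scores[is_pos], scores[~is_pos]
    diff = pos[:, None] - neg[None, :]
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / (len(pos) * len(neg)))

def softmax_roc_auc(logits, labels) -> float:
    """Macro-averaged one-vs-rest ROC AUC computed on softmax probabilities."""
    probs = softmax(np.asarray(logits, dtype=float))
    labels = np.asarray(labels)
    # Assumes every class appears at least once as positive and once as negative.
    aucs = [binary_auc(probs[:, c], labels == c) for c in range(probs.shape[1])]
    return float(np.mean(aucs))

logits = [[2, 0, 0], [0, 2, 0], [0, 0, 2], [3, 0, 0]]
labels = [0, 1, 2, 0]
score = softmax_roc_auc(logits, labels)
```

The pairwise comparison is quadratic in the number of examples. That is fine for a sketch, but a rank-based formula or a library routine would scale better on a full validation set.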
train_cv(path:"Path to data dir", epochs:"Number of unfrozen epochs"=1,
         lr:"Initial learning rate"=0.0003, frz:"Number of frozen epochs"=1,
         pre:"Image presize"=(682, 1024), re:"Image resize"=256, bs:"Batch size"=256,
         smooth:"Label smoothing?"=False, arch:"Architecture"='resnet18',
         dump:"Don't train, just print model"=False, log:"Log w/ W&B"=False,
         save:"Save model based on RocAuc"=False, mixup:"Mixup (0.4 is good)"=0.0,
         tta:"Test-time augmentation"=False, fp16:"Mixed-precision training"=False,
         do_eval:"Evaluate model and save predictions CSV"=False,
         val_fold:"Don't do cross-validation, just do 1 fold"=None,
         pseudo:"Path to pseudo labels to train on"=None,
         export:"Export learner(s) to export_valon{fold}.pkl"=False)
Train models using 5-fold cross-validation, or on a single fold when val_fold is set.
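The control flow described above can be sketched as follows. The loop trains once per validation fold, or once when val_fold is given, and collects a score for each run. The names run_cv and fake_train are hypothetical. fake_train stands in for a call to train and returns a made-up score instead of a real RocAuc.

```python
from statistics import mean
from typing import Callable, List, Optional

N_FOLDS = 5  # matches the 5-fold cross-validation described above

def run_cv(train_fn: Callable[[int], float], val_fold: Optional[int] = None) -> List[float]:
    """Call train_fn once per validation fold and return the per-fold scores."""
    folds = range(N_FOLDS) if val_fold is None else [val_fold]
    return [train_fn(fold) for fold in folds]

# Hypothetical stand-in for training one model; returns a fake score per fold.
def fake_train(fold: int) -> float:
    return 0.9 + 0.01 * fold

scores = run_cv(fake_train)                # all five folds
single = run_cv(fake_train, val_fold=2)    # just one fold, no cross-validation
print(f"mean score over {len(scores)} folds: {mean(scores):.3f}")
```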