This template provides guidelines for contributing your work to the DeepEM Playground. The following sections give a short overview of the template.
Overview
Jupyter Notebooks
Each use case consists of two main notebooks:
1_Development.ipynb – Used for model development and hyperparameter tuning.
2_Inference.ipynb – Used for running inference on trained models.
These notebooks serve as an interface between deep learning (DL) experts and electron microscopy (EM) experts. To ensure consistency and simplify the learning process for EM researchers, the notebooks should follow a standardized structure.
Please update the markdown cells in the notebooks to describe your specific use case.
To assist you:
Markdown text requiring your input is highlighted in red.
Example markdown descriptions that should be modified for your use case are marked in green.
Before submitting your use case, please remove all color formatting.
Project Structure
The deepEM/ folder contains a lightweight library for implementing your use case.
Only modify this library if absolutely necessary.
Otherwise, use the provided modules via appropriate imports.
Your custom code should be placed in the src/ folder.
All custom implementations must provide Inferencer.py, Model.py, and ModelTrainer.py, based on the corresponding modules of the deepEM library.
A simple example is provided by implementing this tutorial.
Adapt this code to implement your use case.
For library documentation and available functions, please see below.
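For orientation, the resulting project layout looks roughly like this (a sketch assembled from the files named in this documentation; the exact placement of the notebooks is an assumption):

    1_Development.ipynb
    2_Inference.ipynb
    configs/
        parameters.json
        README.md
    deepEM/            (library code; modify only if absolutely necessary)
    src/
        Inferencer.py
        Model.py
        ModelTrainer.py
    logs/              (created automatically by the Logger at runtime)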
Model Configuration
The DeepEM library manages model parameters through a configuration file:
configs/parameters.json – Defines all hyperparameters for training.
EM experts can use the API in 1_Development.ipynb to fine-tune hyperparameters via grid search.
Tunable vs. Non-Tunable Parameters
Tunable hyperparameters should be clearly documented for adjustment.
Non-tunable hyperparameters are not accessible to EM experts but should still be included, with explanations, to improve their understanding of the underlying method.
All parameters, both tunable and non-tunable, must be well documented.
For detailed documentation, see configs/README.md.
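Purely as an illustrative sketch (the authoritative schema is described in configs/README.md; the values below are made up, and any key other than epochs, validation_interval, images_to_visualize, train_subset, and reduce_epochs, which are referenced elsewhere in this documentation, is an assumption), a minimal parameters.json could look like:

    {
        "epochs": 100,
        "validation_interval": 10,
        "images_to_visualize": 4,
        "train_subset": 0.2,
        "reduce_epochs": 0.25,
        "learning_rate": 0.001
    }

How tunable and non-tunable parameters are distinguished and documented is defined in configs/README.md.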
DeepEM Documentation
The DeepEM library provides a simple, PyTorch-based framework for implementing, training, tuning, and applying deep learning models. It provides the following modules:
A simple Logger class for monitoring training and tuning of the model.
A module for automatic hyperparameter tuning (ModelTuner).
An AbstractModelTrainer which implements basic learning concepts such as early stopping and model checkpointing.
An AbstractModel to implement the deep learning model, based on the torch.nn.Module class.
An AbstractInference module for running inference on single or multiple files.
This approach not only simplifies implementation for DL experts but also establishes a standardized workflow, reducing the learning curve for EM researchers. It enables them to adapt a use case to a new application simply by changing the training data, without requiring any code changes.
Note: To contribute your work to the DeepEM Playground, you need to implement every @abstractmethod within all abstract classes. All other classes function as helper classes and should only be altered if absolutely necessary.
Legend
Methods marked as @abstractmethod are abstract methods that must be implemented by the DL specialist. All other methods are helper methods and can be overridden if necessary.
The AbstractModel class is a base class for deep learning models, extending
torch.nn.Module. This class provides methods for resetting model parameters,
performing the forward pass, and making predictions. It is intended to be inherited by specific
model classes to define the architecture and training logic.
Class Overview
AbstractModel serves as a foundational model class that defines essential
methods for deep learning models. These include:
reset_model_parameters_recursive - Recursively resets model parameters.
forward - Defines the forward pass, which must be implemented by subclasses.
predict - Makes predictions using the model.
Constructor
__init__()
Initializes the AbstractModel class, which extends torch.nn.Module.
This constructor serves as a base class for all deep learning models, providing methods for resetting model parameters (e.g., between multiple model runs during hyperparameter search), performing the forward pass, and making predictions.
See the PyTorch torch.nn.Module documentation for details.
@abstractmethod
forward(x)
The forward pass of the model.
Subclasses must implement this method, as each model will have its own architecture
and forward pass logic.
Args:
x (torch.Tensor): The input tensor to be passed through the model.
Returns:
torch.Tensor: The output tensor generated by the model.
Example usage: output = model(x)
@abstractmethod
predict(x)
Makes predictions using the model.
This method sets the model into evaluation mode and performs inference without computing gradients
(using torch.no_grad()).
Args:
x (torch.Tensor): The input tensor to generate a prediction.
Returns:
torch.Tensor: The predicted output from the model.
Example usage: predictions = model.predict(x)
reset_model_parameters_recursive()
Recursively resets the parameters of all layers in the PyTorch model.
This method traverses the model and resets the parameters of any layers that have a
reset_parameters method, commonly used in layers like nn.Conv2d and
nn.Linear.
Example usage: model.reset_model_parameters_recursive()
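To make this contract concrete, here is a minimal, hypothetical subclass sketch (the class name and architecture are placeholders, not part of the library; only forward and predict are required by AbstractModel):

    import torch
    import torch.nn as nn
    from deepEM.Model import AbstractModel

    class DenoisingModel(AbstractModel):
        def __init__(self):
            super().__init__()
            # Placeholder architecture; replace with your use case's network.
            self.net = nn.Sequential(
                nn.Conv2d(1, 16, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(16, 1, kernel_size=3, padding=1),
            )

        def forward(self, x):
            return self.net(x)

        def predict(self, x):
            # Evaluation mode and no gradient computation, as documented above.
            self.eval()
            with torch.no_grad():
                return self.forward(x)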
The AbstractModelTrainer class is an abstract base class designed to facilitate the training, validation, and testing of deep learning models. It provides a structured workflow for model training, including dataset handling, logging, checkpointing, and early stopping mechanisms. It allows for automated hyperparameter tuning by leveraging deepEM.ModelTuner.ModelTuner.
Class Overview
AbstractModelTrainer manages the entire training pipeline, ensuring modularity and
flexibility.
It is intended to be subclassed, requiring concrete implementations for setting up the model, datasets,
optimizer, and scheduler. Key functionalities include:
save_checkpoint(epoch, val_loss) - Saves the best model checkpoint based on validation loss.
load_checkpoint(checkpoint_path) - Loads a model checkpoint, including optimizer and scheduler states.
train_epoch(epoch) - Executes a single epoch of training and validation.
test() - Runs the test loop after training.
fit() - Manages the full training process, including early stopping and logging.
This class supports GPU acceleration, integrates with the deepEM.Logger.Logger for tracking training progress, and provides dataset loaders for training, validation, and testing. Subclasses must define the specific architecture and training behavior of the model.
Constructor
__init__(data_path, logger, resume_from_checkpoint=None)
Initializes the trainer class for training, validating, and testing models.
Args:
data_path (str): Path to the dataset used for training, validation, and testing.
logger (Logger): Logger instance for logging events and training progress.
resume_from_checkpoint (str, optional): Path to a checkpoint to resume training from.
Defaults to None.
self.data_path [str] - Path to the training data.
self.logger [deepEM.Logger.Logger] - Logger instance for training and evaluation.
self.resume_from_checkpoint [Optional[str]] - Path to a checkpoint file
for resuming training. Defaults to None.
self.parameter [dict] - Training parameters defined in
configs/parameters.json, containing both tunable and non-tunable parameters. Updated
before each run during sweeps.
self.device [str] - Specifies the computing device: "cuda" if
a GPU is available, otherwise "cpu".
self.model [deepEM.Model.AbstractModel] - The model instance used for
training.
self.train_subset [float] - Fraction (between 0 and 1) of the training
data used for training in hyperparameter sweeps. Defined in configs/parameters.json.
self.reduce_epochs [float] - Fraction (between 0 and 1) of total epochs
used for training in hyperparameter sweeps. Defined in configs/parameters.json.
self.train_loader [torch.utils.data.DataLoader] - DataLoader for training
data.
self.val_loader [torch.utils.data.DataLoader] - DataLoader for validation
data.
self.test_loader [torch.utils.data.DataLoader] - DataLoader for test data.
self.val_vis_loader [torch.utils.data.DataLoader] - DataLoader for
visualizing a subset of predictions on the validation set.
self.test_vis_loader [torch.utils.data.DataLoader] - DataLoader for
visualizing a subset of predictions on the test set.
self.optimizer [torch.optim.Optimizer] - Optimizer for training.
self.scheduler [torch.optim.lr_scheduler] - Learning rate scheduler for
training.
self.best_val_loss [float] - Stores the best validation loss for early
stopping.
self.patience_counter [int] - Counter that increases when validation loss
does not improve after an epoch. Used for early stopping.
self.best_model_wts [state_dict] - Stores the model's best-performing
state dictionary.
self.num_epochs [int] - Number of training epochs, calculated as
self.parameter["epochs"] * self.reduce_epochs.
self.validation_interval [int] - Interval (in training iterations) at
which validation is performed, computed as
max(1, self.parameter["validation_interval"] * self.reduce_epochs).
@abstractmethod
setup_model()
Setup and return the model for training, validation, and testing. Make sure to move the model to self.device.
This method must be implemented by the DL expert.
Returns:
deepEM.Model.AbstractModel: The initialized model ready for training, validation, and testing.
@abstractmethod
setup_datasets()
Setup and return the datasets for training, validation, and testing.
The data path provided by the EM specialist can be accessed via self.data_path.
Returns:
torch.utils.data.Dataset: The training dataset.
torch.utils.data.Dataset: The validation dataset.
torch.utils.data.Dataset: The test dataset.
@abstractmethod
setup_optimizer()
Setup and return the optimizer and learning rate scheduler.
Returns:
torch.optim.Optimizer: The optimizer for the model.
torch.optim.lr_scheduler._LRScheduler: The learning rate scheduler.
@abstractmethod
compute_loss(outputs, targets)
Compute the loss for a batch.
Args:
outputs (torch.Tensor): Model outputs.
targets (torch.Tensor): Ground truth labels.
Returns:
torch.Tensor: Computed loss.
@abstractmethod
train_step(batch_idx, batch)
Perform one training step.
Args:
batch_idx: Index of the current batch.
batch (tuple): A batch of data, e.g., (inputs, targets).
Returns:
torch.Tensor: The loss for this batch.
@abstractmethod
val_step(batch_idx, batch)
Perform one validation step.
Args:
batch_idx: Index of the current batch.
batch: A batch of data defined by the dataset implementation.
Returns:
torch.Tensor: The loss for this batch.
dict: Dictionary of metrics for this batch (e.g., accuracy, F1 score, etc.).
@abstractmethod
test_step(batch_idx, batch)
Perform one test step.
Args:
batch_idx: Index of the current batch.
batch: A batch of data defined by the dataset implementation.
Returns:
torch.Tensor: The loss for this batch.
dict: Dictionary of metrics for this batch (e.g., accuracy, F1 score, etc.).
@abstractmethod
visualize(batch)
Visualizes the model's input and output of a single batch and returns them as PIL images.
Args:
batch: A batch of data defined by the dataset implementation.
Returns:
List[PIL.Image]: List of visualizations for the batch data.
@abstractmethod
inference_metadata()
Returns metadata needed for inference (such as class names) as a dictionary.
This metadata will be saved along with model weights in training checkpoints.
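To illustrate the contract, here is a heavily abridged, hypothetical sketch of a trainer subclass (the model, loss, optimizer settings, and the "learning_rate" key are assumptions; setup_datasets, val_step, test_step, visualize, and inference_metadata must be implemented analogously):

    import torch
    from deepEM.ModelTrainer import AbstractModelTrainer
    from src.Model import DenoisingModel  # hypothetical model from src/Model.py

    class ModelTrainer(AbstractModelTrainer):
        def setup_model(self):
            # Move the model to the trainer's device, as required above.
            return DenoisingModel().to(self.device)

        def setup_optimizer(self):
            # "learning_rate" is an assumed key in configs/parameters.json.
            optimizer = torch.optim.Adam(
                self.model.parameters(), lr=self.parameter["learning_rate"]
            )
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
            return optimizer, scheduler

        def compute_loss(self, outputs, targets):
            return torch.nn.functional.mse_loss(outputs, targets)

        def train_step(self, batch_idx, batch):
            inputs, targets = batch
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            outputs = self.model(inputs)
            return self.compute_loss(outputs, targets)

Once all abstract methods are implemented, training follows the documented workflow, e.g. trainer.fit() followed by trainer.test().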
Prepares the training pipeline by setting up the model, datasets, dataloaders, optimizer, scheduler,
and other configurations.
Args:
config (dict, optional): Dictionary of hyperparameters to override the defaults.
Defaults to None.
train_subset (float, optional): Fraction of the training dataset to use for quick
hyperparameter tuning. Defaults to None.
reduce_epochs (float, optional): Fraction of epochs to use for quick hyperparameter
tuning. Defaults to None.
num_epochs (float, optional): Explicitly sets the number of training epochs. If defined, this overrides self.parameter["epochs"]. Defaults to None.
set_parameters (bool, optional): Whether to set the hyperparameters from the
provided config. Defaults to True.
set_epochs()
Sets the number of epochs and the validation interval based on the configuration. Both may be reduced when the self.reduce_epochs parameter is set; for example, with self.parameter["epochs"] = 100 and self.reduce_epochs = 0.25, a sweep run trains for 25 epochs.
subsample_trainingdata(dataset)
Subsamples the training dataset if a subset fraction is provided (self.train_subset is set).
Args:
dataset (torch.utils.data.Dataset): The training dataset.
Returns:
torch.utils.data.Subset: A subset of the dataset if train_subset is
specified, else the full dataset.
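The library already provides this helper; purely to illustrate the idea, subsampling a dataset to a fraction might look like this sketch:

    import torch
    from torch.utils.data import Subset

    def subsample(dataset, fraction):
        # Keep a random fraction of the samples (illustrative only).
        n = max(1, int(len(dataset) * fraction))
        indices = torch.randperm(len(dataset))[:n].tolist()
        return Subset(dataset, indices)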
Sets up and returns dataloaders for visualizing a subset of validation and test datasets.
This method subsamples the datasets to contain a fixed number of images specified by
self.parameter["images_to_visualize"]. It should be overridden for imbalanced data.
Args:
val_dataset (torch.utils.data.Dataset): The validation dataset.
test_dataset (torch.utils.data.Dataset): The test dataset.
Returns:
torch.utils.data.DataLoader: Dataloader for visualizing a subset of the validation
dataset.
torch.utils.data.DataLoader: Dataloader for visualizing a subset of the test
dataset.
AbstractInference is an abstract base class for performing model inference. It provides methods for model loading, inference execution, and result storage. It is used within the 2_Inference.ipynb notebook.
Logger is a class that provides logging functionality, including checkpoint saving, hyperparameter tracking, and resource monitoring. It prints to the console and also saves log files to disk.
Class Overview
Each time one of the notebooks is run, the Logger creates a log directory at logs/{datafolder}-{currentdatetime}. Within this directory, it creates one subfolder per run: hyperparameter sweep runs are named Sweep_{idx}, the training run is named TrainingRun, and evaluations are saved in the Evaluate subfolder.
Each training run subfolder contains the following directories:
checkpoints to store the latest_model.pth as well as the best_model.pth.
plots to store the training and validation curves.
samples to store qualitative visualizations during validation.
Additionally, each training run subfolder stores a hyperparameters.json with the used hyperparameters, as well as a log.txt that captures the logger's output.
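Putting this together, a single notebook run may produce a layout like the following ({datafolder} and the timestamp are placeholders):

    logs/{datafolder}-{currentdatetime}/
        Sweep_0/
        Sweep_1/
        TrainingRun/
            checkpoints/          (latest_model.pth, best_model.pth)
            plots/                (training and validation curves)
            samples/              (qualitative visualizations)
            hyperparameters.json
            log.txt
        Evaluate/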
Constructor
__init__(data_path: str) -> None
Initializes the Logger and creates a timestamped log directory.
data_path (str): The directory of the training data.
self.data_path [str] - Path to the training data. Also used to store the
best sweep parameters at self.data_path/Sweep_Parameters.
self.root_dir [str] - Root directory for logging, typically set to
logs/timestamp. This directory can contain logs from multiple training runs.
self.log_dir [str] - Directory for storing logs of a specific training
run. It is a subfolder of self.root_dir and also includes evaluation logs.
self.checkpoints_dir [str] - Subdirectory ("checkpoints") within
self.log_dir used for saving model checkpoints.
self.plots_dir [str] - Subdirectory ("plots") within
self.log_dir used for storing training and validation curve plots.
self.samples_dir [str] - Subdirectory ("samples") within
self.log_dir used for saving qualitative samples from validation or testing.
self.logger [logging.Logger] - Instance of the Python logging
module for handling logs.
Saves the best hyperparameters found during a sweep, based on the validation loss, to a JSON file within the training data directory (self.data_path) at {self.data_path}/Sweep_Parameters/best_sweep_parameters.json. If the file already exists, it is only overwritten if the current validation loss is lower than the saved one.
load_best_sweep() -> dict | None
Loads the best hyperparameters found during a sweep from {self.data_path}/Sweep_Parameters/best_sweep_parameters.json. Returns None if the file was not found.
log_hyperparameters(hyperparams: dict) -> None
Saves the current training hyperparameters to a JSON file at
{self.log_dir}/hyperparameters.json.
Saves the test results, including the test loss and other metrics. These are usually saved in the Evaluate subfolder of the logging directory, in the file test_results.txt.
append_test_results(metrics: dict) -> None
Appends additional metrics to the test results. These are usually saved in the Evaluate subfolder of the logging directory, in the file test_results.txt.
get_most_recent_logs() -> dict
Retrieves the most recent log directories for each dataname in the logs folder. This method is used when the EM expert does not define a model path for evaluation; in that case, the most recent logs are used. It returns a dictionary where keys are datanames and values are paths to the most recent log directories.
Resource Monitoring Methods
get_resource_usage() -> tuple(dict, str)
Retrieves current system and GPU resource usage. Returns a tuple containing a dictionary with the resource usage as well as a formatted string.
log_resource_usage() -> None
Logs system resource usage to a file at {self.log_dir}/resource_usage.log.
The ModelTuner class provides a framework for performing hyperparameter tuning on machine learning models using grid search. The class can be extended to implement different search methods (such as random or Bayesian search).
Class Overview
The ModelTuner class automates hyperparameter tuning on the ModelTrainer class. It provides a simple API between DL specialists and EM specialists (users). While DL experts define their model parameters within configs/parameters.json, ipywidgets are used to allow code-independent input by the EM specialists.
The class supports loading and updating configuration files, creating interactive widgets for parameter tuning, performing grid search, and logging the best hyperparameters based on validation performance. It helps streamline the hyperparameter tuning process, optimizing model performance efficiently.
Constructor
__init__(model_trainer, data_path, logger)
Initializes the ModelTuner class with the given model trainer, data path, and logger. It also loads the
configuration and sets hyperparameter tuning options.
Args:
model_trainer (ModelTrainer): The trainer used to train the model during the hyperparameter sweep.
data_path (str): Path to the training data used during the sweep.
logger (Logger): The logger used for recording sweep progress and results.
self.model_trainer [deepEM.ModelTrainer.AbstractModelTrainer] - Model trainer used for hyperparameter search.
self.logger [deepEM.Logger.Logger] - Logger instance for recording sweep progress and results.
self.config [dict] - Dictionary containing hyperparameters for tuning,
loaded from configs/parameters.json.
self.trainsubset [float] - Fraction (between 0 and 1) of the training data
used during hyperparameter sweeps. Defined in configs/parameters.json.
self.reduce_epochs [float] - Fraction (between 0 and 1) of the total
epochs used for training in hyperparameter sweeps. Defined in configs/parameters.json.
self.method [str] - Hyperparameter tuning method. Currently, only
"grid" search is supported.
edit_hyperparameters()
Displays the best hyperparameters from the previous sweep or default values, and allows the user to
modify them via a widget interface.
Returns:
widgets.VBox: A VBox widget containing the editable hyperparameter values.
Example usage: widget = model_tuner.edit_hyperparameters()
update_hyperparameters(widget_box)
Updates the hyperparameter configuration based on the user input in the provided widget box.
Args:
widget_box (widgets.VBox): The widget containing the user-modifiable
hyperparameters.
Returns:
dict: A dictionary of the updated hyperparameters.
Example usage: updated_params = model_tuner.update_hyperparameters(widget_box)
create_hyperparameter_widgets()
Generates the interactive widgets for tuning hyperparameters, displaying both fixed and adjustable
parameters.
Returns:
widgets.VBox: A VBox widget containing the hyperparameter adjustment interface.
Example usage: widgets_box = model_tuner.create_hyperparameter_widgets()
update_config(widget_box)
Updates the configuration based on the widget input and returns the modified config.
Args:
widget_box (widgets.VBox): The widget box containing the modified hyperparameter
values.
Returns:
dict: The updated configuration dictionary.
Example usage: updated_config = model_tuner.update_config(widget_box)
load_config(config_file)
Loads the configuration settings from a JSON file.
Args:
config_file (str): The path to the configuration JSON file.
Returns:
dict: The loaded configuration settings.
Example usage: config = model_tuner.load_config(config_file)
get_default_params()
Extracts and returns the default hyperparameters from the configuration.
Returns:
dict: A dictionary containing the default hyperparameter values.
Example usage: default_params = model_tuner.get_default_params()
prepare_grid_search_space()
Prepares the search space for grid search based on the defined hyperparameters.
Returns:
dict: A dictionary containing the possible values for each hyperparameter.
Example usage: search_space = model_tuner.prepare_grid_search_space()
tune_grid()
Performs a grid search over the hyperparameter search space, evaluating model performance for each
combination of parameters.
Returns:
dict, float: The best hyperparameters and the corresponding validation loss.
Example usage: best_params, best_loss = model_tuner.tune_grid()
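Conceptually (this is a sketch of the idea, not the library's actual implementation), grid search enumerates the Cartesian product of all value lists in the search space:

    import itertools

    # Hypothetical search space; real values come from prepare_grid_search_space().
    search_space = {"learning_rate": [1e-3, 1e-4], "batch_size": [8, 16]}
    keys = list(search_space)
    for values in itertools.product(*(search_space[k] for k in keys)):
        config = dict(zip(keys, values))
        # ... train with `config` and keep the parameters with the
        # lowest validation loss ...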
tune()
Performs the hyperparameter tuning based on the selected method (grid, random, or bayes).
Returns:
dict: The best hyperparameters determined by the tuning method.
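Putting the documented API together, a typical notebook flow could look like this sketch (assuming model_trainer, data_path, and logger have already been set up as described above):

    from deepEM.ModelTuner import ModelTuner

    tuner = ModelTuner(model_trainer, data_path, logger)
    widget = tuner.edit_hyperparameters()         # display the widget in the notebook
    # ... the EM specialist edits the values in the widget ...
    params = tuner.update_hyperparameters(widget)
    best_params = tuner.tune()                    # grid search over the search space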