Fine-Tuning an IS2RE DPP Model


I am working on fine-tuning the IS2RE DimeNet++ model with my own dataset of different adsorbate molecules.
I would like to ask if I can load the checkpoint and further train/validate it with my own dataset, or do I have to train the model from scratch?

Thanks for your help.


Yes, you can certainly do that!

Hi Muhammed,

Can you show me how to do that? Or is there an example of fine-tuning a model?

I tried creating an EnergyTrainer and specifying the task, model, dataset, and optimizer. After I used load_checkpoint on the EnergyTrainer, the train() function did not do anything. A minimal sketch of what I tried is below.
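For reference, this is roughly what I set up (the lmdb and checkpoint paths here are placeholders, and the settings mirror my config; this is a sketch, not my exact script):

import torch
from ocpmodels.trainers import EnergyTrainer

# Task / model / optimizer settings for IS2RE DimeNet++; paths are placeholders.
task = {
    "dataset": "single_point_lmdb",
    "description": "Relaxed state energy prediction from initial structure.",
    "type": "regression",
    "metric": "mae",
    "labels": ["relaxed energy"],
}
model = {
    "name": "dimenetplusplus",
    "hidden_channels": 256,
    "out_emb_channels": 192,
    "num_blocks": 3,
    "cutoff": 6.0,
    "num_radial": 6,
    "num_spherical": 7,
    "num_before_skip": 1,
    "num_after_skip": 2,
    "num_output_layers": 3,
    "regress_forces": False,
    "use_pbc": True,
    "otf_graph": True,
}
dataset = [
    {   # train split; the statistics normalize the regression target
        "src": "training.lmdb",
        "normalize_labels": True,
        "target_mean": -1.525913953781128,
        "target_std": 2.279365062713623,
    },
    {"src": "validation.lmdb"},  # validation split
]
optimizer = {
    "batch_size": 4,
    "eval_batch_size": 4,
    "num_workers": 4,
    "lr_initial": 1.0e-4,
    "max_epochs": 20,
}

trainer = EnergyTrainer(
    task=task,
    model=model,
    dataset=dataset,
    optimizer=optimizer,
    identifier="dpp-is2re-finetune",
)

# Load the pretrained IS2RE checkpoint, then continue training on my data.
trainer.load_checkpoint(checkpoint_path="path/to/checkpoint.pt")
trainer.train()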

Thanks so much for the help!


I found the tutorial on the ocp GitHub page. Thank you for your help!

Hi Muhammed,

When I tried the example for fine-tuning the models, it did not train the model after loading the dataset. The resulting log is attached at the end.

I tried two different IS2RE DPP checkpoints, as well as two sets of .lmdb training/validation/testing datasets. All of these attempts failed to start training the model.

Did I set some parameters incorrectly? Could you help me figure out the problem?


2023-07-19 19:06:38 (INFO): Project root: /codes/ocp
Unknown option: -C
usage: git [--version] [--help] [-c name=value]
           [--exec-path[=<path>]] [--html-path] [--man-path] [--info-path]
           [-p|--paginate|--no-pager] [--no-replace-objects] [--bare]
           [--git-dir=<path>] [--work-tree=<path>] [--namespace=<name>]
           <command> [<args>]
amp: false
cmd:
  checkpoint_dir: ./checkpoints/2023-07-19-19-05-52
  commit: null
  identifier: ''
  logs_dir: ./logs/tensorboard/2023-07-19-19-05-52
  print_every: 10
  results_dir: ./results/2023-07-19-19-05-52
  seed: 0
  timestamp_id: 2023-07-19-19-05-52
dataset:
  normalize_labels: true
  src: training.lmdb
  target_mean: -1.525913953781128
  target_std: 2.279365062713623
gpus: 1
logger: tensorboard
model: dimenetplusplus
model_attributes:
  cutoff: 6.0
  hidden_channels: 256
  num_after_skip: 2
  num_before_skip: 1
  num_blocks: 3
  num_output_layers: 3
  num_radial: 6
  num_spherical: 7
  otf_graph: true
  out_emb_channels: 192
  regress_forces: false
  use_pbc: true
noddp: false
optim:
  batch_size: 4
  eval_batch_size: 4
  lr_gamma: 0.1
  lr_initial: 0.0001
  lr_milestones:
  - 115082
  - 230164
  - 345246
  max_epochs: 20
  num_workers: 4
  warmup_factor: 0.2
  warmup_steps: 57541
slurm: {}
task:
  dataset: single_point_lmdb
  description: Relaxed state energy prediction from initial structure.
  labels:
  - relaxed energy
  metric: mae
  primary_metric: energy_mae
  type: regression
test_dataset:
  src: testing.lmdb
trainer: energy
val_dataset:
  src: validation.lmdb

2023-07-19 19:06:43 (INFO): Batch balancing is disabled for single GPU training.
2023-07-19 19:06:43 (INFO): Batch balancing is disabled for single GPU training.
2023-07-19 19:06:43 (INFO): Batch balancing is disabled for single GPU training.
2023-07-19 19:06:43 (INFO): Loading dataset: single_point_lmdb
2023-07-19 19:06:43 (INFO): Loading model: dimenetplusplus
2023-07-19 19:07:00 (INFO): Loaded DimeNetPlusPlusWrap with 2755462 parameters.
2023-07-19 19:07:00 (WARNING): Model gradient logging to tensorboard not yet supported.
2023-07-19 19:07:00 (INFO): Loading checkpoint from: /
2023-07-19 19:07:00 (INFO): Total time taken: 0.15838170051574707

Hi! I wonder whether you have solved this problem, since I have run into the same issue.

Can you try increasing max_epochs in your config to some large number? (You can always kill your job early once you see the training curves converge.) It's likely the checkpoints have already reached the configured number of training epochs, so when you try to continue training it stops immediately.
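For intuition, here is a simplified sketch of why this happens (it mirrors the shape of the trainer's epoch loop, not the exact ocp code): load_checkpoint restores the epoch counter from the checkpoint, and training iterates from there up to max_epochs.

# Simplified sketch of the resume logic (not the exact ocp code).
start_epoch = 20   # e.g., a checkpoint that was already trained for 20 epochs
max_epochs = 20    # optim.max_epochs from the config above

for epoch in range(start_epoch, max_epochs):
    print(f"training epoch {epoch}")  # never runs: range(20, 20) is empty

# Raising max_epochs (say, to 1000) makes the range non-empty, so
# fine-tuning actually proceeds.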


Thanks! It works!
But I still want to know how many epochs these checkpoints were trained for. Is there any way to obtain this information?


The information you’re looking for is contained within the checkpoint.

import torch

# Load the saved checkpoint dictionary
cp = torch.load("path/to/", map_location="cpu")

cp.keys() lists the other information stored in the checkpoint that may be useful to look at.
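For the epoch count specifically, continuing from the snippet above, a minimal sketch, assuming the checkpoint stores its training progress under an "epoch" key (check cp.keys() to confirm the exact key name in your checkpoint):

# Assumes the checkpoint dict exposes an "epoch" entry; verify with cp.keys().
print(cp["epoch"])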

I think I know how to do it now.
Thank you very much for your reply. I am a beginner and you and your team have helped me a lot.