
[BUG] Error when loading a saved TFTModel checkpoint #2369

Closed
chododom opened this issue May 4, 2024 · 2 comments
Labels
bug Something isn't working triage Issue waiting for triaging

Comments

chododom commented May 4, 2024

Describe the bug
I trained a TFT model with a quantile loss, and I cannot load the checkpoint that was saved at the best epoch for evaluation.

To Reproduce

```python
import os

from darts.models import TFTModel
from darts.utils.likelihood_models import QuantileRegression
from torchmetrics import MetricCollection, MeanAbsoluteError, MeanAbsolutePercentageError


def build_tft(scaler):
    tft = TFTModel(
        input_chunk_length=48,
        output_chunk_length=24,
        hidden_size=8,
        lstm_layers=3,
        num_attention_heads=4,
        full_attention=False,
        feed_forward='GatedResidualNetwork',
        dropout=0.15,
        loss_fn=QuantileRegression(quantiles=[0.1, 0.5, 0.9]),
        likelihood=QuantileRegression(quantiles=[0.1, 0.5, 0.9]),
        use_static_covariates=True,
        torch_metrics=MetricCollection(
            metrics=[MeanAbsoluteError(), MeanAbsolutePercentageError()]
        ),
        batch_size=1024,
        n_epochs=15,
        model_name=f"tft_{8}_{3}_{4}_{0.00003}",
        work_dir=os.path.join(args.save_dir, 'tft'),
        log_tensorboard=True,
        save_checkpoints=True,
        optimizer_kwargs={"lr": 0.00003},
        random_state=42,
        add_encoders={
            'cyclic': {'future': ['month']},
            'datetime_attribute': {'future': ['hour']},
            'transformer': scaler
        }
    )

    return tft


# scaler is a loaded darts.dataprocessing.transformers.Scaler saved during preprocessing
model = build_tft(scaler)
model.fit(train, val_series=val)
```

I let the model train like so and then I want to load the best checkpoint for evaluation:

```python
model = TFTModel.load_from_checkpoint(model_name='tft_8_3_4_0.0003', work_dir='training_data/tft', best=True)
```

but that leads to the following error:

```
TypeError: cannot assign 'darts.utils.likelihood_models.QuantileRegression' as child module 'criterion' (torch.nn.Module or None expected)
```

Can anyone please suggest a solution?
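For context, the error message itself points at the mechanism (this is a reading of the traceback, not confirmed darts internals): the checkpoint's network holds `criterion` as a registered child module, and `torch.nn.Module` refuses to re-assign a registered child with anything that is not an `nn.Module` or `None`. A darts likelihood object is not an `nn.Module`, hence the `TypeError`. A minimal standalone reproduction of that torch behavior:

```python
import torch.nn as nn


class NetWithCriterion(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning an nn.Module registers 'criterion' as a child module.
        self.criterion = nn.MSELoss()


net = NetWithCriterion()
try:
    # Re-assigning a registered child module with a plain object fails,
    # just like assigning a darts likelihood to 'criterion' does.
    net.criterion = object()
except TypeError as e:
    print(e)  # cannot assign ... as child module 'criterion' ...
```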

System (please complete the following information):

  • Python version: [e.g. 3.10.14]
  • darts version: [e.g. 0.28.0]
@chododom chododom added bug Something isn't working triage Issue waiting for triaging labels May 4, 2024
@chododom chododom changed the title [BUG] [BUG] Error when loading a saved TFTModel checkpoint May 4, 2024
chododom commented May 4, 2024

I figured out what is causing the error: when building the model, I use a custom quantile range of [0.1, 0.5, 0.9]. When I leave this parameter out and use the default QuantileRegression, loading works.

dennisbader (Collaborator) commented
Hi @chododom, loss_fn must be None when using a likelihood and vice versa.
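Following the maintainer's note, the fix is to pass only `likelihood` and leave `loss_fn` unset. One way to catch this early is a small pre-flight guard before building the model; the helper below is a hypothetical sketch, not part of darts:

```python
# Hypothetical guard: darts expects loss_fn XOR likelihood, not both.
# Failing fast here gives a clearer message than the TypeError raised
# later when loading the checkpoint.
def check_loss_config(loss_fn=None, likelihood=None):
    if loss_fn is not None and likelihood is not None:
        raise ValueError(
            "Pass either loss_fn (a torch loss) or likelihood "
            "(a darts Likelihood), not both; leave the other as None."
        )
    return {"loss_fn": loss_fn, "likelihood": likelihood}


# Using a likelihood alone is fine:
kwargs = check_loss_config(likelihood="QuantileRegression([0.1, 0.5, 0.9])")

# Setting both, as in the original snippet, is rejected:
try:
    check_loss_config(loss_fn="some torch loss", likelihood="QuantileRegression()")
except ValueError as e:
    print("rejected:", e)
```

With the guard passing, the `TFTModel(...)` call from the report would then receive only the `likelihood` argument, and `load_from_checkpoint` can restore the model.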
