-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
31 lines (24 loc) · 662 Bytes
/
main.py
File metadata and controls
31 lines (24 loc) · 662 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from TorchNet_impl import TensorNet
from train_eval import run_experiment, prepare_data_qm9, extract_max_z
# Script body: build the QM9 data loaders, compute target statistics on the
# training split, construct a TensorNet model and run the experiment.

# prepare_data_qm9 returns an indexable collection of loaders; this script
# uses index 0 for training, 1 for validation and 2 for testing.
# NOTE(review): reduced_set=True presumably selects a smaller subset of QM9 —
# confirm against train_eval.prepare_data_qm9.
loaders = prepare_data_qm9(reduced_set=True)

# Gather the training targets once and derive both statistics from the same
# tensor, rather than iterating over the full training dataset twice.
# NOTE(review): torch.tensor(...) on a list of per-sample `y` values assumes
# each `data.y` holds a single element — confirm against the dataset.
train_targets = torch.tensor([data.y for data in loaders[0].dataset])
mean = train_targets.mean()
std = train_targets.std()

# mean/std are handed to the model; presumably used to (de)normalise the
# predicted target — verify in TorchNet_impl.TensorNet.
model = TensorNet(
    hidden_dims=[64] * 4,
    max_z=extract_max_z(),
    rbf_dim=64,
    equivariance_class='SO(3)',
    predict_type='scalar',
    mean=mean,
    std=std
)

# Run the experiment with up to 100 epochs and a patience value of 15.
run_experiment(
    model=model,
    model_name='TensorNet',
    train_loader=loaders[0],
    val_loader=loaders[1],
    test_loader=loaders[2],
    n_epochs=100,
    patience=15
)