-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTorchNet_impl.py
More file actions
347 lines (282 loc) · 14 KB
/
TorchNet_impl.py
File metadata and controls
347 lines (282 loc) · 14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
from typing import Tuple
import torch
import torch.nn as nn
from torch_geometric.nn.conv.message_passing import MessagePassing
from torch import Tensor, exp
def split_tensor(x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""
Args:
x: Tensor of shape (..., 3, 3)
Returns:
I: Tensor of shape (..., 3, 3)
A: Tensor of shape (..., 3, 3)
S: Tensor of shape (..., 3, 3)
Splits the input tensor into its scalar, vector, and tensor parts.
"""
assert x.shape[-2] == 3 and x.shape[-1] == 3, "Input tensor must have last two dimensions of size 3"
I = torch.diagonal(x, dim1=-2, dim2=-1).mean(dim=-1, keepdim=True).unsqueeze(-1) * torch.eye(3, device=x.device)
A = (x - x.transpose(-2, -1)) / 2
S = (x + x.transpose(-2, -1)) / 2 - I
return I, A, S
def phi(r: Tensor, rc: float = 5) -> Tensor:
    """Smooth cosine cutoff envelope for pairwise interactions.

    Evaluates 0.5 * cos(pi * r / rc) + 0.5 for distances below the cutoff and
    0 at or beyond it, so the weight falls from 1 at r = 0 to 0 at r = rc.

    Args:
        r: Tensor of shape (N,) with distances between two atoms.
        rc: cutoff distance (default: 5).

    Returns:
        Tensor of shape (N,) with the envelope value for every distance.
    """
    inside_cutoff = (r < rc).float()
    angle = r * (torch.pi / rc)
    envelope = .5 * torch.cos(angle) + .5
    return envelope * inside_cutoff
def vector_to_skewtensor(vector):
    """Embed 3-vectors as skew-symmetric 3x3 matrices.

    A vector (x, y, z) becomes
        [[ 0,  z, -y],
         [-z,  0,  x],
         [ y, -x,  0]]
    which is the layout skewtensor_to_vector reads back.

    Args:
        vector: Tensor of shape (..., 3).

    Returns:
        Tensor of shape (..., 3, 3) with vector's dtype and device.
    """
    vx, vy, vz = vector.unbind(dim=-1)
    nil = vector.new_zeros(vector.shape[:-1])
    rows = (
        torch.stack((nil, vz, -vy), dim=-1),
        torch.stack((-vz, nil, vx), dim=-1),
        torch.stack((vy, -vx, nil), dim=-1),
    )
    return torch.stack(rows, dim=-2)
def skewtensor_to_vector(skewtensor):
    """Extract the vector (x, y, z) encoded by a skew-symmetric 3x3 matrix.

    Reads the entries [1, 2] -> x, [2, 0] -> y and [0, 1] -> z, matching the
    layout produced by vector_to_skewtensor. All other entries are ignored, so
    for a non-skew input the result is simply those three entries.

    The previous implementation multiplied an int64 permutation matrix with the
    (floating point) input via `@`; torch.matmul does not type-promote, so it
    raised for every float tensor. Direct indexing avoids the matmul entirely.

    Args:
        skewtensor: Tensor of shape (..., 3, 3).

    Returns:
        Tensor of shape (..., 3) with the input's dtype and device.
    """
    return torch.stack(
        (skewtensor[..., 1, 2], skewtensor[..., 2, 0], skewtensor[..., 0, 1]),
        dim=-1,
    )
def e_rbf(r: Tensor, beta: Tensor|float = 1, mu: Tensor = torch.tensor(0.)) -> Tensor:
"""
Args:
r: Tensor of shape (N,) distance between two atoms
beta: float (default: 1) width of the basis function
mu: Tensor of shape (D,) center of the basis function
Returns:
Tensor of shape (N,D) with the basis function applied to each distance
Exponent
Determines how strongly atoms influence each other, based on their distance.
"""
return exp(-beta * (exp(r).unsqueeze(-1) - mu.unsqueeze(0)) ** 2)
class EquivariantMPLayer(MessagePassing):
    """Equivariant message-passing (interaction) layer on rank-2 node features.

    Node features are stacks of 3x3 matrices, shape [N, C, 3, 3]. Along each
    edge the sender's features are normalized, split into isotropic (I),
    antisymmetric (A) and symmetric-traceless (S) parts, and each part is scaled
    channel-wise by coefficients predicted from a radial basis expansion of the
    interatomic distance, multiplied by the cosine cutoff ``phi``. Messages are
    summed per receiving node and combined with the receiver's own
    channel-mixed features through matrix products in :meth:`update`.

    Args:
        in_channels: channels of the incoming node features.
        out_channels: channels of the outgoing node features.
            NOTE(review): ``message`` multiplies ``out_channels`` coefficients
            with ``in_channels`` feature channels, and ``update`` feeds an
            ``out_channels``-wide product into ``Linear(in_channels, ...)``;
            outside of broadcasting corner cases this only runs when
            ``in_channels == out_channels``.
        equivariance_class: 'SO(3)' or 'O(3)'; selects how aggregated messages
            are combined with the node features in ``update``.
        rbf_dim: number of radial basis functions fed to the message MLP.
    """

    def __init__(self, in_channels, out_channels, equivariance_class='SO(3)', rbf_dim=64):
        # "Add" aggregation. Node features are [N, C, 3, 3], so the node axis
        # is 0 rather than PyG's default -2 (which assumes [N, C]).
        super(EquivariantMPLayer, self).__init__(aggr='add', node_dim=0)
        self.eq_class = equivariance_class
        self.rbf_dim = rbf_dim
        # Cutoff radius: sets the RBF centre range and the phi envelope in message().
        self.rc = torch.tensor(5)
        self.in_dim = in_channels
        self.out_dim = out_channels
        # Distance MLP producing 3 * out_channels coefficients (for I, A and S).
        self.msg_mlp = nn.Sequential(
            nn.Linear(rbf_dim, out_channels), nn.SiLU(),
            nn.Linear(out_channels, out_channels * 2), nn.SiLU(),
            nn.Linear(out_channels * 2, out_channels * 3), nn.SiLU(),
        )
        # First channel mixing, applied to the node's own I, A, S parts.
        # No bias: a bias would be added to all nine matrix entries alike.
        self.upd_scaler = nn.Linear(in_channels, out_channels, bias=False)
        self.upd_vector = nn.Linear(in_channels, out_channels, bias=False)
        self.upd_tensor = nn.Linear(in_channels, out_channels, bias=False)
        # Second channel mixing, applied to the I, A, S parts of the combined product.
        self.upd_final_scaler = nn.Linear(in_channels, out_channels, bias=False)
        self.upd_final_vector = nn.Linear(in_channels, out_channels, bias=False)
        self.upd_final_tensor = nn.Linear(in_channels, out_channels, bias=False)

    @staticmethod
    def _mix_channels(linear, t):
        """Apply ``linear`` over the channel axis of a [N, C, 3, 3] tensor."""
        return linear(t.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x, pos, edge_index):
        """Run one round of message passing.

        Args:
            x: node features [N, in_channels, 3, 3].
            pos: node positions [N, 3].
            edge_index: graph connectivity, passed unchanged to ``propagate``.

        Returns:
            Updated node features [N, out_channels, 3, 3].
        """
        return self.propagate(edge_index, x=x, p=pos)

    def message(self, x_j, p_i, p_j):
        """Build the message sent along every edge j -> i.

        Args:
            x_j: sender features gathered per edge, [E, in_channels, 3, 3].
            p_i: receiver positions per edge, [E, 3].
            p_j: sender positions per edge, [E, 3].

        Returns:
            Per-edge messages [E, C, 3, 3].
        """
        # Normalize by (||X||_F^2 + 1); the squared Frobenius norm equals Tr(X^T X).
        x_j = x_j / (torch.norm(x_j, dim=(-2, -1), p='fro', keepdim=True) ** 2 + 1)
        I, A, S = split_tensor(x_j)
        r = torch.norm(p_i - p_j, dim=-1)  # [E]
        # Exponential RBF parameters: centres evenly spaced between 1 and exp(-rc).
        beta = 2 / self.rbf_dim * (1 - exp(-self.rc))
        beta = 1 / beta ** 2
        mu = torch.linspace(1, exp(-self.rc), self.rbf_dim).to(r.device)
        # Use the layer's own cutoff so phi and the RBF centres cannot disagree.
        w = phi(r, rc=float(self.rc)).unsqueeze(-1) * self.msg_mlp(e_rbf(r, beta=beta, mu=mu))
        f_i, f_a, f_s = w.chunk(3, dim=-1)  # each [E, out_channels]
        return (f_i[..., None, None] * I
                + f_a[..., None, None] * A
                + f_s[..., None, None] * S)

    def update(self, aggr_out, x):
        """Combine aggregated messages with the node's own features.

        Args:
            aggr_out: summed messages per node, [N, C, 3, 3].
            x: node features before this layer, [N, in_channels, 3, 3].

        Returns:
            Updated node features [N, out_channels, 3, 3].

        Raises:
            AttributeError: if ``equivariance_class`` is neither 'SO(3)' nor 'O(3)'.
        """
        # Mix the node's own I, A, S parts to match aggr_out's channel count.
        I, A, S = split_tensor(x)
        Y = (self._mix_channels(self.upd_scaler, I)
             + self._mix_channels(self.upd_vector, A)
             + self._mix_channels(self.upd_tensor, S))
        if self.eq_class == 'SO(3)':
            reps = 2 * torch.matmul(Y, aggr_out)
        elif self.eq_class == 'O(3)':
            reps = torch.matmul(Y, aggr_out) + torch.matmul(aggr_out, Y)
        else:
            raise AttributeError(f"Unknown equivariance class {self.eq_class}")
        # Normalize the decomposed product by (||reps||_F^2 + 1).
        Iu, Au, Su = split_tensor(reps)
        x_norm = torch.norm(reps, dim=(-2, -1), p='fro', keepdim=True) ** 2 + 1
        dX = (self._mix_channels(self.upd_final_scaler, Iu / x_norm)
              + self._mix_channels(self.upd_final_vector, Au / x_norm)
              + self._mix_channels(self.upd_final_tensor, Su / x_norm))
        dX = dX + torch.matmul(dX, dX)
        # With equal widths the raw input is the skip connection; otherwise the
        # channel-mixed Y (already out_channels wide) takes its place.
        return x + dX if self.in_dim == self.out_dim else Y + dX
class TensorNetEmbedding(nn.Module):
    """Embed atoms and their local geometry as rank-2 (3x3) node features.

    For every edge shorter than the cutoff, the unit direction r_hat is turned
    into an isotropic part (identity), an antisymmetric part
    (vector_to_skewtensor(r_hat)) and a symmetric-traceless part
    (r_hat r_hat^T - |r_hat|^2 / 3 * Id). Each part is weighted channel-wise by
    linear maps of an exponential radial basis of the edge LENGTH. The result is
    scaled by a learned embedding of the atom pair and by the cosine cutoff
    phi, then summed onto the receiving node edge_index[1]. A norm-based gating
    plus channel mixing produces the final node features.

    Args:
        hidden_dim: number of feature channels.
        max_z: largest atom-type index accepted by the embedding table.
        rbf_dim: number of radial basis functions.
    """

    def __init__(self, hidden_dim=64, max_z=128, rbf_dim=64):
        super(TensorNetEmbedding, self).__init__()
        self.rc = torch.tensor(5)  # cutoff distance (5 Angstrom, assuming positions in Angstrom)
        self.hid_dim = hidden_dim
        self.rbf_dim = rbf_dim
        self.z_emb = nn.Embedding(max_z + 1, hidden_dim)
        self.z_map = nn.Linear(2 * hidden_dim, hidden_dim)
        self.scalar_map = nn.Linear(rbf_dim, hidden_dim)
        self.vector_map = nn.Linear(rbf_dim, hidden_dim)
        self.tensor_map = nn.Linear(rbf_dim, hidden_dim)
        self.node_mlp = nn.Sequential(nn.LayerNorm(hidden_dim),
                                      nn.Linear(hidden_dim, 2 * hidden_dim), nn.SiLU(),
                                      nn.Linear(2 * hidden_dim, 3 * hidden_dim), nn.SiLU(),
                                      )
        self.node_scalar = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.node_vector = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.node_tensor = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, data):
        """Compute node features of shape [num_nodes, hidden_dim, 3, 3].

        Args:
            data: graph object providing ``pos`` [N, 3], ``edge_index`` [2, E],
                ``z`` (integer atom types, [N]) and ``num_nodes``.
        """
        pos = data.pos
        src, dst = data.edge_index[0], data.edge_index[1]
        r_vec = pos[src] - pos[dst]  # [E, 3]
        dist = r_vec.norm(dim=-1)  # [E]
        # Drop edges beyond the cutoff. The index tensors are filtered together
        # with the geometry so every per-edge tensor below stays aligned.
        keep = dist < self.rc
        src, dst, r_vec, dist = src[keep], dst[keep], r_vec[keep], dist[keep]
        # Unit directions; zero-length edges (e.g. self-loops) become the zero
        # vector instead of NaN.
        r_hat = r_vec / dist.clamp_min(torch.finfo(dist.dtype).tiny).unsqueeze(-1)

        eye = torch.eye(3, device=pos.device, dtype=r_hat.dtype)
        I = eye.unsqueeze(0)  # [1, 3, 3], broadcast over edges
        A = vector_to_skewtensor(r_hat)  # [E, 3, 3]
        # Trace removal only on the diagonal keeps S symmetric, traceless and
        # rotation-equivariant.
        sq_norm = (r_hat * r_hat).sum(dim=-1)  # [E]
        S = r_hat.unsqueeze(-1) * r_hat.unsqueeze(-2) - (sq_norm / 3).view(-1, 1, 1) * eye

        # Learned embedding of the (sender, receiver) atom pair.
        Z = self.z_emb(data.z)  # [N, hidden_dim]
        Z_ij = self.z_map(torch.cat([Z[src], Z[dst]], dim=-1))  # [E, hidden_dim]

        # Radial basis of the edge length (not of the unit vector, whose norm is 1).
        beta = 2 / self.rbf_dim * (1 - exp(-self.rc))
        beta = 1 / beta ** 2
        mu = torch.linspace(1, exp(-self.rc), self.rbf_dim).to(pos.device)
        rbf = e_rbf(dist, beta=beta, mu=mu)  # [E, rbf_dim]
        f_i = self.scalar_map(rbf)[:, :, None, None]
        f_a = self.vector_map(rbf)[:, :, None, None]
        f_s = self.tensor_map(rbf)[:, :, None, None]

        # Edge features [E, hidden_dim, 3, 3].
        X = f_i * I + f_a * A.unsqueeze(1) + f_s * S.unsqueeze(1)
        X = Z_ij[:, :, None, None] * phi(dist, rc=float(self.rc)).view(-1, 1, 1, 1) * X

        # Sum incoming edge features onto their receiving node.
        X_n = X.new_zeros((data.num_nodes, self.hid_dim, 3, 3))
        X_n.index_add_(0, dst, X)

        # Gate the channel-mixed I, A, S parts with an MLP of the per-channel
        # squared Frobenius norms.
        f_i, f_a, f_s = self.node_mlp(X_n.norm(p='fro', dim=(-2, -1)) ** 2).chunk(3, dim=-1)
        I, A, S = split_tensor(X_n)
        I = f_i[:, :, None, None] * self.node_scalar(I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        A = f_a[:, :, None, None] * self.node_vector(A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        S = f_s[:, :, None, None] * self.node_tensor(S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return I + A + S
class ScalarPredictionHead(nn.Module):
    """Rotation-invariant readout producing one scalar per graph.

    Each node's I, A and S parts are reduced to their per-channel squared
    Frobenius norms, concatenated, mapped by an MLP to one value per node, and
    summed over the nodes of each graph.

    Args:
        hidden_dim: number of channels C of the I, A, S inputs.
    """

    def __init__(self, hidden_dim):
        super(ScalarPredictionHead, self).__init__()
        self.predict = nn.Sequential(
            nn.LayerNorm(3 * hidden_dim),
            nn.Linear(3 * hidden_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim // 2), nn.SiLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, I, A, S, batch):
        """Predict one scalar per graph.

        Args:
            I, A, S: node feature parts, each [N, C, 3, 3].
            batch: [N] graph index of every node, or None for a single graph.

        Returns:
            Tensor [num_graphs] with the summed node predictions of each graph.
        """
        invariants = torch.cat([t.norm(dim=(-2, -1)) ** 2 for t in (I, A, S)], dim=-1)  # [N, 3C]
        # squeeze(-1) (not squeeze()) keeps a single-node input one-dimensional.
        node_out = self.predict(invariants).squeeze(-1)  # [N]
        if batch is None:
            batch = torch.zeros(node_out.shape[0], dtype=torch.long, device=node_out.device)
        num_graphs = int(batch.max()) + 1 if batch.numel() > 0 else 0
        return node_out.new_zeros(num_graphs).index_add_(0, batch, node_out)
class VectorPredictionHead(nn.Module):
    """Readout producing one 3-vector per graph from the antisymmetric part A.

    An MLP over the channel axis collapses A to a single channel. The vector
    encoded by the resulting skew-symmetric matrix is extracted per node and
    summed over the nodes of each graph.

    NOTE(review): the MLP acts on each of the nine matrix entries independently,
    with biases and SiLU, so the output is not guaranteed to stay skew-symmetric
    or rotation-equivariant. Confirm this is intended.

    Args:
        hidden_dim: number of channels C of the A input.
    """

    def __init__(self, hidden_dim):
        super(VectorPredictionHead, self).__init__()
        self.predict = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.mask = torch.tensor([])  # not used within this module

    def forward(self, I, A, S, batch):
        """Predict one vector per graph.

        Args:
            I, A, S: node feature parts, each [N, C, 3, 3] (only A is used).
            batch: [N] graph index of every node, or None for a single graph.

        Returns:
            Tensor [num_graphs, 3].
        """
        A = self.predict(A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # [N, 1, 3, 3]
        node_vec = skewtensor_to_vector(A).squeeze(1)  # [N, 3]
        if batch is None:
            batch = torch.zeros(node_vec.shape[0], dtype=torch.long, device=node_vec.device)
        num_graphs = int(batch.max()) + 1 if batch.numel() > 0 else 0
        # Sum node vectors per graph; the target must be [num_graphs, 3] to match the source.
        return node_vec.new_zeros((num_graphs, 3)).index_add_(0, batch, node_vec)
class TensorPredictionHead(nn.Module):
    """Readout producing one antisymmetric 3x3 tensor per graph.

    Two MLP heads over the channel axis of A each yield one skew matrix per
    node. These are turned into vectors a1 and a2, and the antisymmetric part
    of their outer product, (a1 a2^T - a2 a1^T) / 2, is summed over the nodes
    of each graph.

    Args:
        hidden_dim: number of channels C of the A input.
    """

    def __init__(self, hidden_dim):
        super(TensorPredictionHead, self).__init__()
        # The same Linear instances appear in both heads, so these weights are shared.
        shared = [
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
        ]
        self.mlp_a1 = nn.Sequential(
            *shared, nn.Linear(hidden_dim, 1, bias=False)
        )
        self.mlp_a2 = nn.Sequential(
            *shared, nn.Linear(hidden_dim, 1, bias=False)
        )

    def forward(self, I, A, S, batch):
        """Predict one antisymmetric tensor per graph.

        Args:
            I, A, S: node feature parts, each [N, C, 3, 3] (only A is used).
            batch: [N] graph index of every node, or None for a single graph.

        Returns:
            Tensor [num_graphs, 3, 3].
        """
        a_last = A.permute(0, 2, 3, 1)  # channels last for the Linear layers
        a1 = skewtensor_to_vector(self.mlp_a1(a_last).permute(0, 3, 1, 2)).squeeze(1)  # [N, 3]
        a2 = skewtensor_to_vector(self.mlp_a2(a_last).permute(0, 3, 1, 2)).squeeze(1)  # [N, 3]
        outer = a1.unsqueeze(-1) * a2.unsqueeze(-2)  # [N, 3, 3]
        node_t = (outer - outer.transpose(-2, -1)) / 2
        if batch is None:
            batch = torch.zeros(node_t.shape[0], dtype=torch.long, device=node_t.device)
        num_graphs = int(batch.max()) + 1 if batch.numel() > 0 else 0
        # Sum node tensors per graph; the target must be [num_graphs, 3, 3] to match the source.
        return node_t.new_zeros((num_graphs, 3, 3)).index_add_(0, batch, node_t)
class TensorNet(nn.Module):
    """TensorNet-style model: embedding, a stack of equivariant layers, and a readout head.

    Args:
        hidden_dims: channel widths. ``hidden_dims[0]`` is the embedding width
            and layer k maps ``hidden_dims[k]`` to ``hidden_dims[k + 1]``.
        max_z: largest atom-type index accepted by the embedding.
        rbf_dim: number of radial basis functions.
        equivariance_class: 'SO(3)' or 'O(3)', forwarded to every layer.
        predict_type: 'scalar', 'vector' or 'tensor' readout.
        mean, std: de-normalization applied as ``pred * std + mean``; they
            default to zeros / ones of the output shape. Numbers are accepted
            as well as tensors.

    Raises:
        AttributeError: for an unknown ``predict_type``.
    """

    def __init__(self, hidden_dims=(64, 64, 64), max_z=128, rbf_dim=64, equivariance_class='SO(3)', predict_type='scalar', mean=None, std=None):
        super(TensorNet, self).__init__()
        hidden_dims = list(hidden_dims)
        self.embedding = TensorNetEmbedding(hidden_dims[0], max_z, rbf_dim)
        # Layer k consumes the width produced by the previous stage (hidden_dims[k]).
        self.layers = nn.ModuleList([
            EquivariantMPLayer(d_in, d_out, equivariance_class, rbf_dim=rbf_dim)
            for d_in, d_out in zip(hidden_dims[:-1], hidden_dims[1:])
        ])
        heads = {
            'scalar': (ScalarPredictionHead, (1,)),
            'vector': (VectorPredictionHead, (3,)),
            'tensor': (TensorPredictionHead, (3, 3)),
        }
        if predict_type not in heads:
            raise AttributeError(f"Unknown predict_type {predict_type}")
        head_cls, shape = heads[predict_type]
        self.predict = head_cls(hidden_dims[-1])
        self.mean = torch.zeros(shape) if mean is None else torch.as_tensor(mean)
        self.std = torch.ones(shape) if std is None else torch.as_tensor(std)

    def forward(self, data):
        """Predict per-graph targets.

        Args:
            data: graph object providing ``pos``, ``edge_index``, ``z``,
                ``num_nodes`` and optionally ``batch`` (absent/None = one graph).

        Returns:
            De-normalized predictions from the selected head.
        """
        device = data.pos.device
        X = self.embedding(data)
        # Pass the dense edge_index directly. PyG reads a torch.sparse adjacency
        # as transposed (row = target), so a sparse tensor built from edge_index
        # would reverse the message direction relative to the embedding (which
        # aggregates onto edge_index[1]); coalesce() would also merge duplicate edges.
        for layer in self.layers:
            X = layer(X, data.pos, data.edge_index)
        batch = getattr(data, 'batch', None)
        if batch is None:
            batch = torch.zeros(data.pos.shape[0], dtype=torch.long, device=device)
        I, A, S = split_tensor(X)
        return self.predict(I, A, S, batch) * self.std.to(device) + self.mean.to(device)