Numpy计算近邻表时间对比
技術背景
所謂的近鄰表求解,就是給定N個原子的體系,找出滿足cutoff要求的每一對原子。在前面的幾篇博客中,我們分別介紹過CUDA近鄰表計算與JAX-MD關于格點法求解近鄰表的實現。雖然我們從理論上可以知道,用格點法求解近鄰表,在復雜度上肯定是要優于傳統的算法。本文主要從Python代碼的實現上來具體測試一下二者的速度差異,這里使用的硬件還是CPU。
算法解析
若一對原子A和B滿足下述條件,則稱A、B為一對近鄰原子:
\[|\textbf{r}_A-\textbf{r}_B|\leq cutoff \]傳統的求解方法,就是把所有原子間距都計算一遍,然后對每個原子的近鄰原子進行排序,最終按照給定的cutoff截斷值確定相關的近鄰原子。在Python中的實現,因為有numpy這樣的強力工具,我們在計算原子兩兩間距時,只需要對一組維度為(N,D)的原子坐標進行擴維,分別變成(1,N,D)和(N,1,D)大小的原子坐標。然后將二者相減,計算過程中會自動廣播(Broadcast)成(N,N,D)和(N,N,D)的兩個數組進行計算。對得到的結果做一個Norm,就可以得到維度為(N,N)的兩兩間距矩陣。該算法的計算復雜度為O(N^2)。
相對高效的一種求解方案是將原子坐標所在的空間劃分成眾多的小區域,通常我們設定這些小區域為邊長等于cutoff的小正方體。這種設定有一個好處是,我們可以確定每一個正方體的近鄰原子,一定在最靠近其周邊的26個小正方體區域內。這樣一來,我們就不需要去計算全局的兩兩間距,只需要計算單個小正方體內(假定有M個原子)的兩兩間距(M,M),以及單個正方體與周邊正方體內原子的配對間距(M,26M)。之所以這樣分開計算,是為了減少原子跟自身間距的這一項重復計算。那么對于整個空間的原子,就需要計算(N,27M)這么多次的原子間距,是一個復雜度為O(NlogN)的算法。
Numpy代碼實現
這里我們基于Python中的numpy框架來實現這兩個不同的計算近鄰表的算法。其實當我們使用numpy來進行計算的時候,應當盡可能的避免循環體的使用。但是這里僅演示兩種算法的差異性,因此在實現格點法的時候偷了點懶,用了兩個for循環,感興趣的童鞋可以自行優化。
import time
from itertools import chain
from operator import itemgetter
import numpy as np
# 在格點法中,為了避免重復計算,我們可以僅計算一半的近鄰格點中的原子間距
NEIGHBOUR_GRID = np.array([
[-1, 1, 0],
[-1, -1, 1],
[-1, 0, 1],
[-1, 1, 1],
[ 0, -1, 1],
[ 0, 0, 1],
[ 0, 1, 0],
[ 0, 1, 1],
[ 1, -1, 1],
[ 1, 0, 0],
[ 1, 0, 1],
[ 1, 1, 0],
[ 1, 1, 1]], np.int32)
# 原始的兩兩間距計算方法,需要排序
def get_neighbours_by_dist(crd, cutoff):
large_dis = np.tril(np.ones((crd.shape[0], crd.shape[0])) * 999)
# (N, N)
dis = np.linalg.norm(crd[None] - crd[:, None], axis=-1) + large_dis
# (N, M)
neigh = np.argsort(dis, axis=-1)
# (N, M)
cut = np.take_along_axis(dis, neigh, axis=1)
# (2, P)
pairs = np.where(cut <= cutoff)
# (P, )
pairs_id0 = pairs[0]
pairs_id1 = neigh[pairs]
# (P, 2)
sort_args = np.argsort(pairs_id0)
return np.hstack((pairs_id0[..., None], pairs_id1[..., None]))[sort_args]
# 格點法計算近鄰表,先分格點,然后分兩個模塊計算單格點內原子間距,和中心格點-周邊格點內的原子間距
def get_neighbours_by_grid(crd, cutoff):
# (D, )
min_xyz = np.min(crd, axis=0)
max_xyz = np.max(crd, axis=0)
space = max_xyz - min_xyz
grids = np.ceil(space / cutoff).astype(np.int32)
num_grids = np.product(grids)
buffer = (grids * cutoff - space) / 2
start_crd = min_xyz - buffer
# (N, D)
grid_id = ((crd - start_crd) // cutoff).astype(np.int32)
grid_coe = np.array([1, grids[0], grids[1]], np.int32)
# (N, )
grid_id_1d = np.sum(grid_id * grid_coe, axis=-1).astype(np.int32)
# (N, 2)
grid_id_dict = np.ndenumerate(grid_id_1d)
# (G, *)
grid_dict = dict.fromkeys(range(num_grids), ())
for index, value in grid_id_dict:
grid_dict[value] += index
neighbour_grid = (NEIGHBOUR_GRID * grid_coe).sum(axis=-1).astype(np.int32)
neighbour_pairs = []
for i in range(num_grids):
if grid_dict[i]:
keeps = np.where((neighbour_grid + i < num_grids) & (neighbour_grid + i >= 0))[0]
neighbour_grid_keep = neighbour_grid[keeps] + i
grid_atoms = np.array(list(grid_dict[i]), np.int32)
try:
grid_neighbours = np.array(list(chain(*itemgetter(*neighbour_grid_keep)(grid_dict))), np.int32)
except TypeError:
if neighbour_grid_keep.size == 0:
grid_neighbours = np.array([], np.int32)
else:
grid_neighbours = np.array(list(itemgetter(*neighbour_grid_keep)(grid_dict)), np.int32)
grid_crds = crd[grid_atoms]
grid_neighbour_crds = crd[grid_neighbours]
large_dis = np.tril(np.ones((grid_crds.shape[0], grid_crds.shape[0])) * 999)
# 單格點內部原子間距
grid_dis = np.linalg.norm(grid_crds[None] - grid_crds[:, None], axis=-1) + large_dis
grid_pairs = np.argsort(grid_dis, axis=-1)
grid_cut = np.take_along_axis(grid_dis, grid_pairs, axis=-1)
pairs = np.where(grid_cut <= cutoff)
pairs_id0 = grid_atoms[pairs[0]]
pairs_id1 = grid_atoms[grid_pairs[pairs]]
neighbour_pairs.extend(list(np.hstack((pairs_id0[..., None], pairs_id1[..., None]))))
# 中心格點-周邊格點內原子間距
grid_dis = np.linalg.norm(grid_crds[:, None] - grid_neighbour_crds[None], axis=-1)
grid_pairs = np.argsort(grid_dis, axis=-1)
grid_cut = np.take_along_axis(grid_dis, grid_pairs, axis=-1)
pairs = np.where(grid_cut <= cutoff)
pairs_id0 = grid_atoms[pairs[0]]
pairs_id1 = grid_neighbours[grid_pairs[pairs]]
neighbour_pairs.extend(list(np.hstack((pairs_id0[..., None], pairs_id1[..., None]))))
neighbour_pairs = np.sort(np.array(neighbour_pairs), axis=-1)
sort_args = np.argsort(neighbour_pairs[:, 0])
return neighbour_pairs[sort_args]
# 時間測算函數
def benchmark(N, cutoff=0.3, D=3):
crd = np.random.random((N, D)).astype(np.float32) * np.array([3., 4., 5.], np.float32)
# Solution 1
time0 = time.time()
neighbours_1 = get_neighbours_by_dist(crd, cutoff)
time1 = time.time()
record_1 = time1 - time0
# Solution 2
time0 = time.time()
neighbours_2 = get_neighbours_by_grid(crd, cutoff)
time1 = time.time()
record_2 = time1 - time0
for pair in neighbours_1:
if (np.isin(neighbours_2, pair).sum(axis=-1) < 2).all():
print (pair)
assert neighbours_1.shape == neighbours_2.shape
return record_1, record_2
# 繪圖主函數
if __name__ == '__main__':
import matplotlib.pyplot as plt
sizes = range(1000, 10000, 1000)
time_dis = []
time_grid = []
for size in sizes:
print (size)
times = benchmark(size)
time_dis.append(times[0])
time_grid.append(times[1])
plt.figure()
plt.title('Neighbour List Calculation Time')
plt.plot(sizes, time_dis, color='black', label='Full Connect')
plt.plot(sizes, time_grid, color='blue', label='Cell List')
plt.xlabel('Size')
plt.ylabel('Time/s')
plt.legend()
plt.grid()
plt.show()
上述代碼的運行結果如下圖所示:
其實因為格點法中使用了for循環的問題,函數效率并不高。因此在體系非常小的場景下(比如只有幾十個原子的體系),本文用到的格點法代碼效率并不如計算所有的原子兩兩間距。但是畢竟格點法的復雜度較低,因此在運行過程中隨著體系的增長,格點法的優勢也越來越大。
近鄰表計算與分子動力學模擬
在分子動力學模擬中計算長程相互作用時,會經常使用到近鄰表。如果要在GPU上實現格點近鄰算法,有可能會遇到這樣的一些問題:
- GPU更加擅長處理靜態Shape的張量,因此往往會使用一個
最大近鄰數,對每一個原子的近鄰原子標號進行限制,一般不允許滿足cutoff的近鄰原子數超過最大近鄰數,否則這個cutoff就失去意義了。而如果單個原子的近鄰原子數量低于最大近鄰數,這時候就會用一個沒有意義的數對剩下分配好的張量空間進行填充(Padding),這樣一來會帶來很多不必要的計算。 - 在運行分子動力學模擬的過程中,體系原子的坐標在不斷的變化,近鄰表也會隨之變化,而此時的最大近鄰數有可能無法存儲完整的cutoff內的原子。
總結概要
本文介紹了在Python的numpy框架下計算近鄰表的兩種不同算法的原理以及復雜度,另有分別對應的兩種代碼實現。在實際使用中,我們更偏向于第二種算法的使用。因為對于第一種算法來說,哪怕是一個10000個原子的小體系,如果要計算兩兩間距,也會變成10000*10000這么大的一個張量的運算。可想而知,這樣計算的效率肯定是比較低下的。
版權聲明
本文首發鏈接為:https://www.cnblogs.com/dechinphy/p/cell-list.html
作者ID:DechinPhy
更多原著文章:https://www.cnblogs.com/dechinphy/
請博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html
總結
以上是生活随笔為你收集整理的Numpy计算近邻表时间对比的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 网站优化之favicon.ico
- 下一篇: 如果有人在你的论坛、博客,乱留言、乱回复