import math
from pathlib import Path

from scipy import sparse

import torch


def to_coo_scipy(x):
    """Convert a sparse COO torch tensor into a ``scipy.sparse.coo_matrix``.

    Only the first two rows of the index tensor are used, so ``x`` is treated
    as a 2-D matrix. ``.numpy()`` is called on the indices and values, so the
    tensor has to live in host memory.
    """
    indices_1 = x._indices().numpy()
    values_1 = x._values().numpy()
    return sparse.coo_matrix((values_1, (indices_1[0], indices_1[1])), shape=x.shape)


def sparse_grad_output(a, b):
    """Build a random tensor shaped like the result of ``torch.sparse.mm(a, b)``.

    When the product is sparse, the random values are masked with the
    product's (coalesced) sparsity pattern so the returned tensor is sparse as
    well; otherwise a dense random tensor of the same shape is returned.
    """
    c = torch.sparse.mm(a, b)
    if c.is_sparse:
        c2 = torch.rand_like(c.to_dense())
        return c2.sparse_mask(c.coalesce())
    else:
        return torch.rand_like(c)


def _parse_header(file):
    """Read the first line of an open ``.smtx`` file as ``(nrows, ncols, nnz)``.

    The header is a comma-separated triple of integers. ``int()`` ignores
    surrounding whitespace, so both ``"a, b, c"`` and ``"a,b,c"`` are accepted.
    """
    nrows, ncols, nnz = (int(el) for el in file.readline().split(","))
    return nrows, ncols, nnz


def read_matrix_params(path):
    """Return ``((nrows, ncols), nnz)`` read from the header line of ``path``."""
    with open(path) as file:
        nrows, ncols, nnz = _parse_header(file)
    return (nrows, ncols), nnz


def csr_to_coo(indices, indptr, shape):
    """Expand CSR column indices / row pointers into a ``2 x nnz`` COO index tensor.

    Args:
        indices: column index of every stored element.
        indptr: row pointers; the elements of row ``i`` occupy positions
            ``indptr[i]`` up to (not including) ``indptr[i + 1]``.
        shape: ``(n_rows, n_cols)`` of the matrix; only ``n_rows`` is used.

    Returns:
        A ``torch.long`` tensor whose first row holds row indices and whose
        second row holds column indices.
    """
    n_rows, _ = shape
    cols = indices
    rows = [0] * len(cols)
    for i in range(n_rows):
        for j in range(indptr[i], indptr[i + 1]):
            rows[j] = i
    return torch.tensor([rows, cols], dtype=torch.long)


def load_sparse_matrix(path, device):
    """Load the sparsity pattern stored at ``path`` as a sparse COO tensor.

    The file is read as three lines: the ``nrows, ncols, nnz`` header, the
    whitespace-separated row pointers, and the whitespace-separated column
    indices. Only the pattern comes from the file; the values are freshly
    drawn ``torch.double`` standard-normal samples.
    """
    with open(path) as file:
        nrows, ncols, nnz = _parse_header(file)
        index_pointers = [int(el) for el in file.readline().split()]
        indices = [int(el) for el in file.readline().split()]

    data = torch.randn(nnz, dtype=torch.double)
    shape = (nrows, ncols)
    return torch.sparse_coo_tensor(
        csr_to_coo(indices, index_pointers, shape), data, shape, device=device
    )


def gen_vector(path, device):
    """Return a random dense ``torch.double`` vector of length ``nrows``.

    ``nrows`` is taken from the header of the matrix file at ``path``; the
    rest of the file is not needed and is not read.
    """
    (nrows, _), _ = read_matrix_params(path)
    return torch.randn(nrows, dtype=torch.double, device=device)


def gen_matrix(path, device):
    """Return a random dense ``torch.double`` matrix of shape ``(nrows, ncols)``.

    The shape is taken from the header of the matrix file at ``path``; the
    rest of the file is not needed and is not read.
    """
    (nrows, ncols), _ = read_matrix_params(path)
    return torch.randn(nrows, ncols, dtype=torch.double, device=device)


def _collect_dlmc_files(dataset_path, hidden_size, sparsity, n_limit):
    """Scan ``{dataset_path}/{sparsity}`` recursively for usable ``.smtx`` files.

    At most ``n_limit`` files are inspected (the limit counts inspected files,
    not matches). A file whose column count equals ``hidden_size`` is listed
    as a left operand, a file whose row count equals ``hidden_size`` as a
    right operand; a file may appear in both lists.

    Returns:
        ``(x_files, y_files)`` as lists of POSIX path strings.
    """
    files = Path(f"{dataset_path}/{sparsity}").glob("**/*.smtx")
    print(dataset_path, hidden_size, sparsity)
    x_files, y_files = [], []
    for index, f in enumerate(files):
        if index >= n_limit:
            break
        print(".", end="")  # progress marker, one dot per inspected file
        posix_path = f.as_posix()
        (nrows, ncols), _ = read_matrix_params(posix_path)
        if ncols == hidden_size:
            x_files.append(posix_path)
        if nrows == hidden_size:
            y_files.append(posix_path)
    print()
    return x_files, y_files


def load_spmv_dataset(dataset_path, hidden_size, sparsity, device, n_limit=math.inf):
    """load_spmv_dataset loads a DLMC dataset for a sparse matrix-vector multiplication (SPMV) performance test.

    Yields ``(x, y)`` pairs where ``x`` is a sparse matrix and ``y`` is a
    random dense vector. This is a generator, so nothing is scanned or
    printed until iteration starts.

    Args:
        dataset_path:
            path of the dataset from DLMC collection.
        hidden_size:
            Selects the files to use: ``x`` comes from files with
            ``hidden_size`` columns, ``y`` is sized from files with
            ``hidden_size`` rows.
        sparsity:
            Name of the sub-folder of ``dataset_path`` to scan.
        device:
            Whether to place the Tensor on a GPU or CPU.
        n_limit:
            Maximum number of files to inspect.
    """
    x_files, y_files = _collect_dlmc_files(dataset_path, hidden_size, sparsity, n_limit)

    # zip() pairs files positionally and stops at the shorter list.
    for fx, fy in zip(x_files, y_files):
        x = load_sparse_matrix(fx, device)
        y = gen_vector(fy, device)
        yield (x, y)


def load_spmm_dataset(
    dataset_path, hidden_size, sparsity, spmm_type, device, n_limit=math.inf
):
    """load_spmm_dataset loads a DLMC dataset for a sparse matrix-matrix multiplication (SPMM) performance test.

    Yields ``(x, y)`` pairs where ``x`` is a sparse matrix and ``y`` is either
    a random dense matrix or a sparse matrix, depending on ``spmm_type``.
    This is a generator, so nothing is scanned or printed until iteration
    starts.

    Args:
        dataset_path:
            path of the dataset from DLMC collection.
        hidden_size:
            Selects the files to use: ``x`` comes from files with
            ``hidden_size`` columns, ``y`` from files with ``hidden_size``
            rows.
        sparsity:
            Name of the sub-folder of ``dataset_path`` to scan.
        spmm_type:
            ``"sparse@dense"`` makes ``y`` a dense random matrix; any other
            value makes ``y`` a sparse matrix loaded from file.
        device:
            Whether to place the Tensor on a GPU or CPU.
        n_limit:
            Maximum number of files to inspect.
    """
    x_files, y_files = _collect_dlmc_files(dataset_path, hidden_size, sparsity, n_limit)

    # zip() pairs files positionally and stops at the shorter list.
    for fx, fy in zip(x_files, y_files):
        x = load_sparse_matrix(fx, device)
        y = (
            gen_matrix(fy, device)
            if spmm_type == "sparse@dense"
            else load_sparse_matrix(fy, device)
        )
        yield (x, y)


def load_dlmc_dataset(
    dataset_path,
    operation,
    hidden_size,
    sparsity,
    device,
    requires_grad,
    n_limit=math.inf,
):
    """load_dlmc_dataset loads a DLMC dataset for a matmul performance test.

    Yields one dict per operand pair with the keys ``x``/``y`` (operands as
    loaded) and ``dx``/``dy`` (dense counterparts). When ``device == "cpu"``
    the dict also has ``sx``/``sy`` (SciPy COO matrix for sparse operands,
    NumPy array for dense ones). When ``requires_grad`` is true it also has
    ``sparse_grad_output``/``grad_output``, and all four tensors have
    ``requires_grad`` enabled.

    Args:
        dataset_path:
            path of the dataset from DLMC collection.
        operation:
            This value allows tensors for `sparse@sparse`|`sparse@dense`|`sparse@vector` operations.
        hidden_size:
            This value allows tensors of varying sizes.
        sparsity:
            This value allows tensors of varying sparsities.
        device:
            Whether to place the Tensor on a GPU or CPU.
        requires_grad:
            Loads the dataset for backward test.
        n_limit:
            This value allows a dataset with some limit size.

    Raises:
        ValueError: if ``operation`` is not one of the supported values. As
            this is a generator, the error surfaces on the first iteration.
    """
    if operation == "sparse@sparse" or operation == "sparse@dense":
        collection = load_spmm_dataset(
            dataset_path, hidden_size, sparsity, operation, device, n_limit
        )
    elif operation == "sparse@vector":
        collection = load_spmv_dataset(
            dataset_path, hidden_size, sparsity, device, n_limit
        )
    else:
        # Previously this fell through and failed later with an unbound
        # `collection`; fail early with an actionable message instead.
        raise ValueError(
            f"unsupported operation {operation!r}; expected 'sparse@sparse', "
            "'sparse@dense' or 'sparse@vector'"
        )
    scipy_vars = {}
    backward_vars = {}
    for x, y in collection:
        if device == "cpu":
            # Built before requires_grad_ is enabled below, while .numpy()
            # can still be called on the dense operands.
            scipy_vars = {
                "sx": to_coo_scipy(x) if x.is_sparse else x.numpy(),
                "sy": to_coo_scipy(y) if y.is_sparse else y.numpy(),
            }
        if not requires_grad:
            dx = x.to_dense() if x.is_sparse else x
            dy = y.to_dense() if y.is_sparse else y
        else:
            c = sparse_grad_output(x, y)
            backward_vars = {
                "sparse_grad_output": c,
                "grad_output": c.to_dense() if c.is_sparse else c,
            }
            x.requires_grad_(True)
            y.requires_grad_(True)
            # Detached dense copies get their own gradient tracking, separate
            # from x and y.
            dx = x.to_dense().detach() if x.is_sparse else x.clone().detach()
            dy = y.to_dense().detach() if y.is_sparse else y.clone().detach()
            dx.requires_grad_(True)
            dy.requires_grad_(True)
        yield {"x": x, "y": y, "dx": dx, "dy": dy, **scipy_vars, **backward_vars}