xref: /aosp_15_r20/external/pytorch/benchmarks/sparse/dlmc/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport math
2*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerfrom scipy import sparse
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
def to_coo_scipy(x):
    """Build a ``scipy.sparse.coo_matrix`` from a sparse COO torch tensor.

    Args:
        x: sparse tensor whose raw ``_indices()`` / ``_values()`` buffers are
            exported with ``.numpy()``. Only the first two index rows are
            used, so the tensor is presumably 2-D and CPU-resident.

    Returns:
        A ``coo_matrix`` with the same shape as ``x``.
    """
    coords = x._indices().numpy()
    entries = x._values().numpy()
    row_idx, col_idx = coords[0], coords[1]
    return sparse.coo_matrix((entries, (row_idx, col_idx)), shape=x.shape)
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
def sparse_grad_output(a, b):
    """Create a random tensor shaped like the result of ``torch.sparse.mm(a, b)``.

    When the product is sparse, the random values are restricted to the
    sparsity pattern of the (coalesced) product; otherwise a dense random
    tensor of the same shape is returned.
    """
    product = torch.sparse.mm(a, b)
    if not product.is_sparse:
        return torch.rand_like(product)
    # Draw dense noise first, then keep only the entries present in the product.
    noise = torch.rand_like(product.to_dense())
    return noise.sparse_mask(product.coalesce())
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
def read_matrix_params(path):
    """Read the header line of a matrix file.

    The first line is expected to hold three integers separated by ``", "``:
    number of rows, number of columns and number of non-zeros.

    Returns:
        ``((nrows, ncols), nnz)``.
    """
    with open(path) as handle:
        header = handle.readline()
    nrows, ncols, nnz = [int(token) for token in header.split(", ")]
    return (nrows, ncols), nnz
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker
def csr_to_coo(indices, indptr, shape):
    """Expand CSR column indices / row pointers into a ``2 x nnz`` COO index tensor.

    Args:
        indices: column index of every stored entry.
        indptr: row pointer array; entries ``indptr[i]:indptr[i + 1]`` belong to row ``i``.
        shape: ``(n_rows, n_cols)`` pair; only the row count drives the expansion.

    Returns:
        ``torch.long`` tensor whose first row holds row indices and second row
        holds column indices.
    """
    row_count, _ = shape
    row_of_entry = [0] * len(indices)
    for row in range(row_count):
        first, past_last = indptr[row], indptr[row + 1]
        for pos in range(first, past_last):
            row_of_entry[pos] = row
    return torch.tensor([row_of_entry, indices], dtype=torch.long)
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker
def load_sparse_matrix(path, device):
    """Load the sparsity pattern stored at ``path`` as a sparse COO tensor.

    The file is read as three lines: a ``"nrows, ncols, nnz"`` header, the
    whitespace-separated CSR row pointers, and the whitespace-separated column
    indices. Values are not read from the file; they are drawn from
    ``torch.randn`` in double precision.

    Args:
        path: file to read.
        device: device passed to ``torch.sparse_coo_tensor``.

    Returns:
        Sparse COO tensor of shape ``(nrows, ncols)``.
    """
    with open(path) as handle:
        nrows, ncols, nnz = (int(tok) for tok in handle.readline().split(", "))
        pointer_tokens = handle.readline().split()
        index_tokens = handle.readline().split()

    row_pointers = [int(tok) for tok in pointer_tokens]
    col_indices = [int(tok) for tok in index_tokens]
    values = torch.randn(nnz, dtype=torch.double)
    dims = (nrows, ncols)
    coo_indices = csr_to_coo(col_indices, row_pointers, dims)
    return torch.sparse_coo_tensor(coo_indices, values, dims, device=device)
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker
def gen_vector(path, device):
    """Create a random dense vector sized by the row count in a matrix file header.

    Only the first line of the file (``"nrows, ncols, nnz"``) is read; the
    CSR body lines are not needed to size the vector, so they are left
    untouched instead of being read and split for nothing.

    Args:
        path: matrix file whose header supplies the length.
        device: device on which the vector is allocated.

    Returns:
        1-D ``torch.double`` tensor of length ``nrows`` filled by ``torch.randn``.

    Raises:
        ValueError: if the header does not consist of exactly three integers.
    """
    with open(path) as file:
        # Unpack all three fields so a malformed header still fails loudly.
        nrows, _ncols, _nnz = (int(el) for el in file.readline().split(", "))
    return torch.randn(nrows, dtype=torch.double, device=device)
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker
def gen_matrix(path, device):
    """Create a random dense matrix sized by the header of a matrix file.

    Only the first line of the file (``"nrows, ncols, nnz"``) is read; the
    CSR body lines are not needed to size the matrix, so they are left
    untouched instead of being read and split for nothing.

    Args:
        path: matrix file whose header supplies the shape.
        device: device on which the matrix is allocated.

    Returns:
        ``torch.double`` tensor of shape ``(nrows, ncols)`` filled by ``torch.randn``.

    Raises:
        ValueError: if the header does not consist of exactly three integers.
    """
    with open(path) as file:
        # Unpack all three fields so a malformed header still fails loudly.
        nrows, ncols, _nnz = (int(el) for el in file.readline().split(", "))
    return torch.randn(nrows, ncols, dtype=torch.double, device=device)
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker
def load_spmv_dataset(dataset_path, hidden_size, sparsity, device, n_limit=math.inf):
    """Yield ``(sparse_matrix, dense_vector)`` pairs for an SPMV performance test.

    Every ``*.smtx`` file below ``{dataset_path}/{sparsity}`` is inspected (at
    most ``n_limit`` files). Files whose column count equals ``hidden_size``
    become left operands, files whose row count equals ``hidden_size`` size
    the right-hand vectors; the two lists are paired positionally.

    Args:
        dataset_path:
            path of the dataset from DLMC collection.
        hidden_size:
            dimension the matrices are filtered on.
        sparsity:
            name of the sub-folder (sparsity level) to scan.
        device:
            device on which the tensors are created.
        n_limit:
            maximum number of files to inspect.
    """
    root = Path(f"{dataset_path}/{sparsity}")
    candidates = root.glob("**/*.smtx")
    print(dataset_path, hidden_size, sparsity)
    x_files, y_files = [], []
    for seen, entry in enumerate(candidates):
        if seen >= n_limit:
            break
        print(".", end="")
        posix_path = entry.as_posix()
        (n_rows, n_cols), _ = read_matrix_params(posix_path)
        if n_cols == hidden_size:
            x_files.append(posix_path)
        if n_rows == hidden_size:
            y_files.append(posix_path)
    print()

    for x_path, y_path in zip(x_files, y_files):
        matrix = load_sparse_matrix(x_path, device)
        vector = gen_vector(y_path, device)
        yield (matrix, vector)
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker
def load_spmm_dataset(
    dataset_path, hidden_size, sparsity, spmm_type, device, n_limit=math.inf
):
    """Yield ``(x, y)`` operand pairs for an SPMM performance test.

    Every ``*.smtx`` file below ``{dataset_path}/{sparsity}`` is inspected (at
    most ``n_limit`` files). Files whose column count equals ``hidden_size``
    become left operands, files whose row count equals ``hidden_size`` become
    right operands; the two lists are paired positionally.

    Args:
        dataset_path:
            path of the dataset from DLMC collection.
        hidden_size:
            dimension the matrices are filtered on.
        sparsity:
            name of the sub-folder (sparsity level) to scan.
        spmm_type:
            ``"sparse@dense"`` makes ``y`` a random dense matrix; any other
            value loads ``y`` as a sparse matrix.
        device:
            device on which the tensors are created.
        n_limit:
            maximum number of files to inspect.
    """
    root = Path(f"{dataset_path}/{sparsity}")
    candidates = root.glob("**/*.smtx")
    print(dataset_path, hidden_size, sparsity)
    x_files, y_files = [], []
    for seen, entry in enumerate(candidates):
        if seen >= n_limit:
            break
        print(".", end="")
        posix_path = entry.as_posix()
        (n_rows, n_cols), _ = read_matrix_params(posix_path)
        if n_cols == hidden_size:
            x_files.append(posix_path)
        if n_rows == hidden_size:
            y_files.append(posix_path)
    print()

    for x_path, y_path in zip(x_files, y_files):
        left = load_sparse_matrix(x_path, device)
        if spmm_type == "sparse@dense":
            right = gen_matrix(y_path, device)
        else:
            right = load_sparse_matrix(y_path, device)
        yield (left, right)
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker
def load_dlmc_dataset(
    dataset_path,
    operation,
    hidden_size,
    sparsity,
    device,
    requires_grad,
    n_limit=math.inf,
):
    """load_dlmc_dataset loads a DLMC dataset for a matmul performance test.

    This is a generator: nothing is read (and no argument is validated) until
    the first item is requested.

    Args:
        dataset_path:
            path of the dataset from DLMC collection.
        operation:
            One of `sparse@sparse`, `sparse@dense` or `sparse@vector`; selects
            which kind of right-hand operand is produced.
        hidden_size:
            This value allows tensors of varying sizes.
        sparsity:
            This value allows tensors of varying sparsities.
        device:
            Whether to place the Tensor on a GPU or CPU. scipy operands are
            only produced when this compares equal to the string ``"cpu"``.
        requires_grad:
            Loads the dataset for backward test.
        n_limit:
            This value allows a dataset with some limit size.

    Yields:
        dict with keys ``"x"``, ``"y"`` (operands as loaded) and ``"dx"``,
        ``"dy"`` (dense counterparts); plus ``"sx"``, ``"sy"`` (scipy / numpy
        operands) when ``device == "cpu"``; plus ``"sparse_grad_output"`` and
        ``"grad_output"`` when ``requires_grad`` is true.

    Raises:
        ValueError: if ``operation`` is not one of the supported values.
    """
    if operation in ("sparse@sparse", "sparse@dense"):
        collection = load_spmm_dataset(
            dataset_path, hidden_size, sparsity, operation, device, n_limit
        )
    elif operation == "sparse@vector":
        collection = load_spmv_dataset(
            dataset_path, hidden_size, sparsity, device, n_limit
        )
    else:
        # Previously this fell through and failed later with an opaque
        # UnboundLocalError on `collection`.
        raise ValueError(
            f"unsupported operation {operation!r}; expected 'sparse@sparse', "
            "'sparse@dense' or 'sparse@vector'"
        )
    scipy_vars = {}
    backward_vars = {}
    for x, y in collection:
        if device == "cpu":
            scipy_vars = {
                "sx": to_coo_scipy(x) if x.is_sparse else x.numpy(),
                "sy": to_coo_scipy(y) if y.is_sparse else y.numpy(),
            }
        if not requires_grad:
            dx = x.to_dense() if x.is_sparse else x
            dy = y.to_dense() if y.is_sparse else y
        else:
            # The gradient seed is computed before the operands are marked as
            # requiring grad.
            c = sparse_grad_output(x, y)
            backward_vars = {
                "sparse_grad_output": c,
                "grad_output": c.to_dense() if c.is_sparse else c,
            }
            x.requires_grad_(True)
            y.requires_grad_(True)
            # Dense copies are detached so they get their own autograd leaves.
            dx = x.to_dense().detach() if x.is_sparse else x.clone().detach()
            dy = y.to_dense().detach() if y.is_sparse else y.clone().detach()
            dx.requires_grad_(True)
            dy.requires_grad_(True)
        yield {"x": x, "y": y, "dx": dx, "dy": dy, **scipy_vars, **backward_vars}
214