1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2015 Robert Jarzmik <[email protected]>
4  *
5  * Scatterlist splitting helpers.
6  */
7 
8 #include <linux/scatterlist.h>
9 #include <linux/slab.h>
10 
/*
 * Bookkeeping for one output scatterlist produced by sg_split().
 */
struct sg_splitter {
	/* Destination array, allocated by sg_split() with nents entries. */
	struct scatterlist *out_sg;

	/* First input entry that contributes bytes to this split. */
	struct scatterlist *in_sg0;
	/* Bytes at the start of in_sg0 that precede this split. */
	off_t skip_sg0;
	/* Number of input entries spanned by this split. */
	int nents;
	/* Bytes taken from the last spanned input entry. */
	unsigned int length_last_sg;
};
19 
/*
 * sg_calculate_split - work out which input entries feed each output split
 * @in: input scatterlist
 * @nents: number of entries of @in to walk (physical or DMA-mapped count)
 * @nb_splits: number of output splits; must be at least 1
 * @skip: bytes to skip at the start of @in before the first split begins
 * @sizes: @nb_splits byte counts, one per split
 * @splitters: @nb_splits descriptors filled in by this function
 * @mapped: when true, measure entries with sg_dma_len() instead of ->length
 *
 * For every split, records the first contributing input entry, the offset
 * into it, the number of input entries spanned and the number of bytes taken
 * from the last one. The ->out_sg member is left untouched.
 *
 * Return: 0 on success, -EINVAL if @in is too short to provide @skip plus all
 * of @sizes, or if some split would end up without any entry.
 */
static int sg_calculate_split(struct scatterlist *in, int nents, int nb_splits,
			      off_t skip, const size_t *sizes,
			      struct sg_splitter *splitters, bool mapped)
{
	const int total_splits = nb_splits;
	int i;
	unsigned int sglen;
	size_t size = sizes[0], len;
	struct sg_splitter *curr = splitters;
	struct scatterlist *sg;

	for (i = 0; i < total_splits; i++) {
		splitters[i].in_sg0 = NULL;
		splitters[i].nents = 0;
	}

	for_each_sg(in, sg, nents, i) {
		sglen = mapped ? sg_dma_len(sg) : sg->length;
		/* Entries lying entirely inside the leading skip contribute nothing. */
		if (skip > sglen) {
			skip -= sglen;
			continue;
		}

		len = min_t(size_t, size, sglen - skip);
		if (!curr->in_sg0) {
			curr->in_sg0 = sg;
			curr->skip_sg0 = skip;
		}
		size -= len;
		curr->nents++;
		curr->length_last_sg = len;

		/*
		 * The current split ended strictly inside this entry: start the
		 * following split(s) within the same entry.
		 */
		while (!size && (skip + len < sglen) && (--nb_splits > 0)) {
			curr++;
			size = *(++sizes);
			skip += len;
			len = min_t(size_t, size, sglen - skip);

			curr->in_sg0 = sg;
			curr->skip_sg0 = skip;
			curr->nents = 1;
			curr->length_last_sg = len;
			size -= len;
		}
		skip = 0;

		/* The current split ended exactly at the end of this entry. */
		if (!size && --nb_splits > 0) {
			curr++;
			size = *(++sizes);
		}

		/*
		 * If the last split ended strictly inside an entry, the loop
		 * above already brought nb_splits to 0 and the test just above
		 * decremented it once more. Testing only for 0 here would keep
		 * walking and append every remaining input entry to the last
		 * split.
		 */
		if (nb_splits <= 0)
			break;
	}

	if (size)
		return -EINVAL;

	/*
	 * Every split must own at least one entry: the output builders write
	 * the last entry of each split unconditionally. This also rejects
	 * zero-sized trailing splits left unassigned when @in runs out.
	 */
	for (i = 0; i < total_splits; i++)
		if (!splitters[i].in_sg0)
			return -EINVAL;

	return 0;
}
76 
sg_split_phys(struct sg_splitter * splitters,const int nb_splits)77 static void sg_split_phys(struct sg_splitter *splitters, const int nb_splits)
78 {
79 	int i, j;
80 	struct scatterlist *in_sg, *out_sg;
81 	struct sg_splitter *split;
82 
83 	for (i = 0, split = splitters; i < nb_splits; i++, split++) {
84 		in_sg = split->in_sg0;
85 		out_sg = split->out_sg;
86 		for (j = 0; j < split->nents; j++, out_sg++) {
87 			*out_sg = *in_sg;
88 			if (!j) {
89 				out_sg->offset += split->skip_sg0;
90 				out_sg->length -= split->skip_sg0;
91 			}
92 			sg_dma_address(out_sg) = 0;
93 			sg_dma_len(out_sg) = 0;
94 			in_sg = sg_next(in_sg);
95 		}
96 		out_sg[-1].length = split->length_last_sg;
97 		sg_mark_end(out_sg - 1);
98 	}
99 }
100 
sg_split_mapped(struct sg_splitter * splitters,const int nb_splits)101 static void sg_split_mapped(struct sg_splitter *splitters, const int nb_splits)
102 {
103 	int i, j;
104 	struct scatterlist *in_sg, *out_sg;
105 	struct sg_splitter *split;
106 
107 	for (i = 0, split = splitters; i < nb_splits; i++, split++) {
108 		in_sg = split->in_sg0;
109 		out_sg = split->out_sg;
110 		for (j = 0; j < split->nents; j++, out_sg++) {
111 			sg_dma_address(out_sg) = sg_dma_address(in_sg);
112 			sg_dma_len(out_sg) = sg_dma_len(in_sg);
113 			if (!j) {
114 				sg_dma_address(out_sg) += split->skip_sg0;
115 				sg_dma_len(out_sg) -= split->skip_sg0;
116 			}
117 			in_sg = sg_next(in_sg);
118 		}
119 		sg_dma_len(--out_sg) = split->length_last_sg;
120 	}
121 }
122 
123 /**
124  * sg_split - split a scatterlist into several scatterlists
125  * @in: the input sg list
126  * @in_mapped_nents: the result of a dma_map_sg(in, ...), or 0 if not mapped.
127  * @skip: the number of bytes to skip in the input sg list
128  * @nb_splits: the number of desired sg outputs
129  * @split_sizes: the respective size of each output sg list in bytes
130  * @out: an array where to store the allocated output sg lists
131  * @out_mapped_nents: the resulting sg lists mapped number of sg entries. Might
132  *                    be NULL if sglist not already mapped (in_mapped_nents = 0)
133  * @gfp_mask: the allocation flag
134  *
135  * This function splits the input sg list into nb_splits sg lists, which are
136  * allocated and stored into out.
137  * The @in is split into :
138  *  - @out[0], which covers bytes [@skip .. @skip + @split_sizes[0] - 1] of @in
139  *  - @out[1], which covers bytes [@skip + split_sizes[0] ..
140  *                                 @skip + @split_sizes[0] + @split_sizes[1] -1]
141  * etc ...
142  * It will be the caller's duty to kfree() out array members.
143  *
144  * Returns 0 upon success, or error code
145  */
sg_split(struct scatterlist * in,const int in_mapped_nents,const off_t skip,const int nb_splits,const size_t * split_sizes,struct scatterlist ** out,int * out_mapped_nents,gfp_t gfp_mask)146 int sg_split(struct scatterlist *in, const int in_mapped_nents,
147 	     const off_t skip, const int nb_splits,
148 	     const size_t *split_sizes,
149 	     struct scatterlist **out, int *out_mapped_nents,
150 	     gfp_t gfp_mask)
151 {
152 	int i, ret;
153 	struct sg_splitter *splitters;
154 
155 	splitters = kcalloc(nb_splits, sizeof(*splitters), gfp_mask);
156 	if (!splitters)
157 		return -ENOMEM;
158 
159 	ret = sg_calculate_split(in, sg_nents(in), nb_splits, skip, split_sizes,
160 			   splitters, false);
161 	if (ret < 0)
162 		goto err;
163 
164 	ret = -ENOMEM;
165 	for (i = 0; i < nb_splits; i++) {
166 		splitters[i].out_sg = kmalloc_array(splitters[i].nents,
167 						    sizeof(struct scatterlist),
168 						    gfp_mask);
169 		if (!splitters[i].out_sg)
170 			goto err;
171 	}
172 
173 	/*
174 	 * The order of these 3 calls is important and should be kept.
175 	 */
176 	sg_split_phys(splitters, nb_splits);
177 	if (in_mapped_nents) {
178 		ret = sg_calculate_split(in, in_mapped_nents, nb_splits, skip,
179 					 split_sizes, splitters, true);
180 		if (ret < 0)
181 			goto err;
182 		sg_split_mapped(splitters, nb_splits);
183 	}
184 
185 	for (i = 0; i < nb_splits; i++) {
186 		out[i] = splitters[i].out_sg;
187 		if (out_mapped_nents)
188 			out_mapped_nents[i] = splitters[i].nents;
189 	}
190 
191 	kfree(splitters);
192 	return 0;
193 
194 err:
195 	for (i = 0; i < nb_splits; i++)
196 		kfree(splitters[i].out_sg);
197 	kfree(splitters);
198 	return ret;
199 }
200 EXPORT_SYMBOL(sg_split);
201