xref: /aosp_15_r20/system/core/fs_mgr/libsnapshot/libsnapshot_cow/create_cow.cpp (revision 00c7fec1bb09f3284aad6a6f96d2f63dfc3650ad)
1 #include <linux/types.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <sys/resource.h>
5 #include <sys/time.h>
6 #include <sys/types.h>
7 #include <unistd.h>
8 
9 #include <condition_variable>
10 #include <cstring>
11 #include <future>
12 #include <iostream>
13 #include <limits>
14 #include <mutex>
15 #include <string>
16 #include <thread>
17 #include <unordered_map>
18 #include <vector>
19 
20 #include <android-base/file.h>
21 #include <android-base/logging.h>
22 #include <android-base/stringprintf.h>
23 #include <android-base/unique_fd.h>
24 #include <ext4_utils/ext4_utils.h>
25 #include <storage_literals/storage_literals.h>
26 
27 #include <android-base/chrono_utils.h>
28 #include <android-base/scopeguard.h>
29 #include <android-base/strings.h>
30 
31 #include <gflags/gflags.h>
32 #include <libsnapshot/cow_writer.h>
33 
34 #include <openssl/sha.h>
35 
// Command-line flags. main() reads them as FLAGS_source, FLAGS_target and
// FLAGS_compression; source and target must both be non-empty for any work
// to be done.
DEFINE_string(source, "", "Source partition image");
DEFINE_string(target, "", "Target partition image");
// An empty value leaves CreateSnapshot's built-in default ("lz4") in place.
DEFINE_string(compression, "lz4",
              "Compression algorithm. Default is set to lz4. Available options: lz4, zstd, gz");
40 
41 namespace android {
42 namespace snapshot {
43 
// Brings the size literals used below (1_MiB, 64_KiB, 4_KiB) into scope.
using namespace android::storage_literals;
using namespace android;
using android::base::unique_fd;

// Writer factory and writer interface used to emit the snapshot patch.
using android::snapshot::CreateCowWriter;
using android::snapshot::ICowWriter;
50 
51 class CreateSnapshot {
52   public:
53     CreateSnapshot(const std::string& src_file, const std::string& target_file,
54                    const std::string& patch_file, const std::string& compression);
55     bool CreateSnapshotPatch();
56 
57   private:
58     /* source.img */
59     std::string src_file_;
60     /* target.img */
61     std::string target_file_;
62     /* snapshot-patch generated */
63     std::string patch_file_;
64 
65     /*
66      * Active file which is being parsed by this instance.
67      * It will either be source.img or target.img.
68      */
69     std::string parsing_file_;
70     bool create_snapshot_patch_ = false;
71 
72     const int kNumThreads = 6;
73     const size_t kBlockSizeToRead = 1_MiB;
74     const size_t compression_factor_ = 64_KiB;
75     size_t replace_ops_ = 0, copy_ops_ = 0, zero_ops_ = 0, in_place_ops_ = 0;
76 
77     std::unordered_map<std::string, int> source_block_hash_;
78     std::mutex source_block_hash_lock_;
79 
80     std::unique_ptr<ICowWriter> writer_;
81     std::mutex write_lock_;
82 
83     std::unique_ptr<uint8_t[]> zblock_;
84 
85     std::string compression_ = "lz4";
86     unique_fd cow_fd_;
87     unique_fd target_fd_;
88 
89     std::vector<uint64_t> zero_blocks_;
90     std::vector<uint64_t> replace_blocks_;
91     std::unordered_map<uint64_t, uint64_t> copy_blocks_;
92 
93     const int BLOCK_SZ = 4_KiB;
94     void SHA256(const void* data, size_t length, uint8_t out[32]);
IsBlockAligned(uint64_t read_size)95     bool IsBlockAligned(uint64_t read_size) { return ((read_size & (BLOCK_SZ - 1)) == 0); }
96     bool ReadBlocks(off_t offset, const int skip_blocks, const uint64_t dev_sz);
97     std::string ToHexString(const uint8_t* buf, size_t len);
98 
99     bool CreateSnapshotFile();
100     bool FindSourceBlockHash();
101     bool PrepareParse(std::string& parsing_file, const bool createSnapshot);
102     bool ParsePartition();
103     void PrepareMergeBlock(const void* buffer, uint64_t block, std::string& block_hash);
104     bool WriteV3Snapshots();
105     size_t PrepareWrite(size_t* pending_ops, size_t start_index);
106 
107     bool CreateSnapshotWriter();
108     bool WriteOrderedSnapshots();
109     bool WriteNonOrderedSnapshots();
110     bool VerifyMergeOrder();
111 };
112 
CreateSnapshotLogger(android::base::LogId,android::base::LogSeverity severity,const char *,const char *,unsigned int,const char * message)113 void CreateSnapshotLogger(android::base::LogId, android::base::LogSeverity severity, const char*,
114                           const char*, unsigned int, const char* message) {
115     if (severity == android::base::ERROR) {
116         fprintf(stderr, "%s\n", message);
117     } else {
118         fprintf(stdout, "%s\n", message);
119     }
120 }
121 
CreateSnapshot(const std::string & src_file,const std::string & target_file,const std::string & patch_file,const std::string & compression)122 CreateSnapshot::CreateSnapshot(const std::string& src_file, const std::string& target_file,
123                                const std::string& patch_file, const std::string& compression)
124     : src_file_(src_file), target_file_(target_file), patch_file_(patch_file) {
125     if (!compression.empty()) {
126         compression_ = compression;
127     }
128 }
129 
PrepareParse(std::string & parsing_file,const bool createSnapshot)130 bool CreateSnapshot::PrepareParse(std::string& parsing_file, const bool createSnapshot) {
131     parsing_file_ = parsing_file;
132     create_snapshot_patch_ = createSnapshot;
133 
134     if (createSnapshot) {
135         cow_fd_.reset(open(patch_file_.c_str(), O_RDWR | O_CREAT | O_TRUNC, 0666));
136         if (cow_fd_ < 0) {
137             PLOG(ERROR) << "Failed to open the snapshot-patch file: " << patch_file_;
138             return false;
139         }
140 
141         target_fd_.reset((open(parsing_file_.c_str(), O_RDONLY)));
142         if (target_fd_ < 0) {
143             LOG(ERROR) << "open failed: " << parsing_file_;
144             return false;
145         }
146         zblock_ = std::make_unique<uint8_t[]>(BLOCK_SZ);
147         std::memset(zblock_.get(), 0, BLOCK_SZ);
148     }
149     return true;
150 }
151 
152 /*
153  * Create per-block sha256 hash of source partition
154  */
FindSourceBlockHash()155 bool CreateSnapshot::FindSourceBlockHash() {
156     if (!PrepareParse(src_file_, false)) {
157         return false;
158     }
159     return ParsePartition();
160 }
161 
162 /*
163  * Create snapshot file by comparing sha256 per block
164  * of target.img with the constructed per-block sha256 hash
165  * of source partition.
166  */
CreateSnapshotFile()167 bool CreateSnapshot::CreateSnapshotFile() {
168     if (!PrepareParse(target_file_, true)) {
169         return false;
170     }
171     return ParsePartition();
172 }
173 
174 /*
175  * Creates snapshot patch file by comparing source.img and target.img
176  */
CreateSnapshotPatch()177 bool CreateSnapshot::CreateSnapshotPatch() {
178     if (!FindSourceBlockHash()) {
179         return false;
180     }
181     return CreateSnapshotFile();
182 }
183 
SHA256(const void * data,size_t length,uint8_t out[32])184 void CreateSnapshot::SHA256(const void* data, size_t length, uint8_t out[32]) {
185     SHA256_CTX c;
186     SHA256_Init(&c);
187     SHA256_Update(&c, data, length);
188     SHA256_Final(out, &c);
189 }
190 
ToHexString(const uint8_t * buf,size_t len)191 std::string CreateSnapshot::ToHexString(const uint8_t* buf, size_t len) {
192     char lookup[] = "0123456789abcdef";
193     std::string out(len * 2 + 1, '\0');
194     char* outp = out.data();
195     for (; len > 0; len--, buf++) {
196         *outp++ = (char)lookup[*buf >> 4];
197         *outp++ = (char)lookup[*buf & 0xf];
198     }
199     return out;
200 }
201 
PrepareMergeBlock(const void * buffer,uint64_t block,std::string & block_hash)202 void CreateSnapshot::PrepareMergeBlock(const void* buffer, uint64_t block,
203                                        std::string& block_hash) {
204     if (std::memcmp(zblock_.get(), buffer, BLOCK_SZ) == 0) {
205         std::lock_guard<std::mutex> lock(write_lock_);
206         zero_blocks_.push_back(block);
207         return;
208     }
209 
210     auto iter = source_block_hash_.find(block_hash);
211     if (iter != source_block_hash_.end()) {
212         std::lock_guard<std::mutex> lock(write_lock_);
213         // In-place copy is skipped
214         if (block != iter->second) {
215             copy_blocks_[block] = iter->second;
216         } else {
217             in_place_ops_ += 1;
218         }
219         return;
220     }
221     std::lock_guard<std::mutex> lock(write_lock_);
222     replace_blocks_.push_back(block);
223 }
224 
PrepareWrite(size_t * pending_ops,size_t start_index)225 size_t CreateSnapshot::PrepareWrite(size_t* pending_ops, size_t start_index) {
226     size_t num_ops = *pending_ops;
227     uint64_t start_block = replace_blocks_[start_index];
228     size_t nr_consecutive = 1;
229     num_ops -= 1;
230     while (num_ops) {
231         uint64_t next_block = replace_blocks_[start_index + nr_consecutive];
232         if (next_block != start_block + nr_consecutive) {
233             break;
234         }
235         nr_consecutive += 1;
236         num_ops -= 1;
237     }
238     return nr_consecutive;
239 }
240 
CreateSnapshotWriter()241 bool CreateSnapshot::CreateSnapshotWriter() {
242     uint64_t dev_sz = lseek(target_fd_.get(), 0, SEEK_END);
243     CowOptions options;
244     options.compression = compression_;
245     options.num_compress_threads = 2;
246     options.batch_write = true;
247     options.cluster_ops = 600;
248     options.compression_factor = compression_factor_;
249     options.max_blocks = {dev_sz / options.block_size};
250     writer_ = CreateCowWriter(3, options, std::move(cow_fd_));
251     return true;
252 }
253 
WriteNonOrderedSnapshots()254 bool CreateSnapshot::WriteNonOrderedSnapshots() {
255     zero_ops_ = zero_blocks_.size();
256     for (auto it = zero_blocks_.begin(); it != zero_blocks_.end(); it++) {
257         if (!writer_->AddZeroBlocks(*it, 1)) {
258             return false;
259         }
260     }
261     std::string buffer(compression_factor_, '\0');
262 
263     replace_ops_ = replace_blocks_.size();
264     size_t blocks_to_compress = replace_blocks_.size();
265     size_t num_ops = 0;
266     size_t block_index = 0;
267     while (blocks_to_compress) {
268         num_ops = std::min((compression_factor_ / BLOCK_SZ), blocks_to_compress);
269         auto linear_blocks = PrepareWrite(&num_ops, block_index);
270         if (!android::base::ReadFullyAtOffset(target_fd_.get(), buffer.data(),
271                                               (linear_blocks * BLOCK_SZ),
272                                               replace_blocks_[block_index] * BLOCK_SZ)) {
273             LOG(ERROR) << "Failed to read at offset: " << replace_blocks_[block_index] * BLOCK_SZ
274                        << " size: " << linear_blocks * BLOCK_SZ;
275             return false;
276         }
277         if (!writer_->AddRawBlocks(replace_blocks_[block_index], buffer.data(),
278                                    linear_blocks * BLOCK_SZ)) {
279             LOG(ERROR) << "AddRawBlocks failed";
280             return false;
281         }
282 
283         block_index += linear_blocks;
284         blocks_to_compress -= linear_blocks;
285     }
286     if (!writer_->Finalize()) {
287         return false;
288     }
289     return true;
290 }
291 
WriteOrderedSnapshots()292 bool CreateSnapshot::WriteOrderedSnapshots() {
293     std::unordered_map<uint64_t, uint64_t> overwritten_blocks;
294     std::vector<std::pair<uint64_t, uint64_t>> merge_sequence;
295     for (auto it = copy_blocks_.begin(); it != copy_blocks_.end(); it++) {
296         if (overwritten_blocks.count(it->second)) {
297             replace_blocks_.push_back(it->first);
298             continue;
299         }
300         overwritten_blocks[it->first] = it->second;
301         merge_sequence.emplace_back(std::make_pair(it->first, it->second));
302     }
303     // Sort the blocks so that if the blocks are contiguous, it would help
304     // compress multiple blocks in one shot based on the compression factor.
305     std::sort(replace_blocks_.begin(), replace_blocks_.end());
306 
307     copy_ops_ = merge_sequence.size();
308     for (auto it = merge_sequence.begin(); it != merge_sequence.end(); it++) {
309         if (!writer_->AddCopy(it->first, it->second, 1)) {
310             return false;
311         }
312     }
313 
314     return true;
315 }
316 
VerifyMergeOrder()317 bool CreateSnapshot::VerifyMergeOrder() {
318     unique_fd read_fd;
319     read_fd.reset(open(patch_file_.c_str(), O_RDONLY));
320     if (read_fd < 0) {
321         PLOG(ERROR) << "Failed to open the snapshot-patch file: " << patch_file_;
322         return false;
323     }
324     CowReader reader;
325     if (!reader.Parse(read_fd)) {
326         LOG(ERROR) << "Parse failed";
327         return false;
328     }
329 
330     if (!reader.VerifyMergeOps()) {
331         LOG(ERROR) << "MergeOps Order is wrong";
332         return false;
333     }
334     return true;
335 }
336 
WriteV3Snapshots()337 bool CreateSnapshot::WriteV3Snapshots() {
338     if (!CreateSnapshotWriter()) {
339         return false;
340     }
341     if (!WriteOrderedSnapshots()) {
342         return false;
343     }
344     if (!WriteNonOrderedSnapshots()) {
345         return false;
346     }
347     if (!VerifyMergeOrder()) {
348         return false;
349     }
350 
351     LOG(INFO) << "In-place: " << in_place_ops_ << " Zero: " << zero_ops_
352               << " Replace: " << replace_ops_ << " copy: " << copy_ops_;
353     return true;
354 }
355 
ReadBlocks(off_t offset,const int skip_blocks,const uint64_t dev_sz)356 bool CreateSnapshot::ReadBlocks(off_t offset, const int skip_blocks, const uint64_t dev_sz) {
357     unique_fd fd(TEMP_FAILURE_RETRY(open(parsing_file_.c_str(), O_RDONLY)));
358     if (fd < 0) {
359         LOG(ERROR) << "open failed: " << parsing_file_;
360         return false;
361     }
362 
363     loff_t file_offset = offset;
364     const uint64_t read_sz = kBlockSizeToRead;
365     std::unique_ptr<uint8_t[]> buffer = std::make_unique<uint8_t[]>(read_sz);
366 
367     while (true) {
368         size_t to_read = std::min((dev_sz - file_offset), read_sz);
369 
370         if (!android::base::ReadFullyAtOffset(fd.get(), buffer.get(), to_read, file_offset)) {
371             LOG(ERROR) << "Failed to read block from block device: " << parsing_file_
372                        << " at offset: " << file_offset << " read-size: " << to_read
373                        << " block-size: " << dev_sz;
374             return false;
375         }
376 
377         if (!IsBlockAligned(to_read)) {
378             LOG(ERROR) << "unable to parse the un-aligned request: " << to_read;
379             return false;
380         }
381 
382         size_t num_blocks = to_read / BLOCK_SZ;
383         uint64_t buffer_offset = 0;
384         off_t foffset = file_offset;
385 
386         while (num_blocks) {
387             const void* bufptr = (char*)buffer.get() + buffer_offset;
388             uint64_t blkindex = foffset / BLOCK_SZ;
389 
390             uint8_t checksum[32];
391             SHA256(bufptr, BLOCK_SZ, checksum);
392             std::string hash = ToHexString(checksum, sizeof(checksum));
393 
394             if (create_snapshot_patch_) {
395                 PrepareMergeBlock(bufptr, blkindex, hash);
396             } else {
397                 std::lock_guard<std::mutex> lock(source_block_hash_lock_);
398                 {
399                     if (source_block_hash_.count(hash) == 0) {
400                         source_block_hash_[hash] = blkindex;
401                     }
402                 }
403             }
404             buffer_offset += BLOCK_SZ;
405             foffset += BLOCK_SZ;
406             num_blocks -= 1;
407         }
408 
409         file_offset += (skip_blocks * to_read);
410         if (file_offset >= dev_sz) {
411             break;
412         }
413     }
414 
415     return true;
416 }
417 
ParsePartition()418 bool CreateSnapshot::ParsePartition() {
419     unique_fd fd(TEMP_FAILURE_RETRY(open(parsing_file_.c_str(), O_RDONLY)));
420     if (fd < 0) {
421         LOG(ERROR) << "open failed: " << parsing_file_;
422         return false;
423     }
424 
425     uint64_t dev_sz = lseek(fd.get(), 0, SEEK_END);
426     if (!dev_sz) {
427         LOG(ERROR) << "Could not determine block device size: " << parsing_file_;
428         return false;
429     }
430 
431     if (!IsBlockAligned(dev_sz)) {
432         LOG(ERROR) << "dev_sz: " << dev_sz << " is not block aligned";
433         return false;
434     }
435 
436     int num_threads = kNumThreads;
437 
438     std::vector<std::future<bool>> threads;
439     off_t start_offset = 0;
440     const int skip_blocks = num_threads;
441 
442     while (num_threads) {
443         threads.emplace_back(std::async(std::launch::async, &CreateSnapshot::ReadBlocks, this,
444                                         start_offset, skip_blocks, dev_sz));
445         start_offset += kBlockSizeToRead;
446         num_threads -= 1;
447         if (start_offset >= dev_sz) {
448             break;
449         }
450     }
451 
452     bool ret = true;
453     for (auto& t : threads) {
454         ret = t.get() && ret;
455     }
456 
457     if (ret && create_snapshot_patch_ && !WriteV3Snapshots()) {
458         LOG(ERROR) << "Snapshot Write failed";
459         return false;
460     }
461 
462     return ret;
463 }
464 
465 }  // namespace snapshot
466 }  // namespace android
467 
// Usage text registered with gflags and printed when --source or --target is
// missing. The examples use the flag syntax because main() reads its inputs
// only from the flags; positional arguments are not consulted.
constexpr char kUsage[] = R"(
NAME
    create_snapshot - Create snapshot patches by comparing two partition images

SYNOPSIS
    create_snapshot --source=<source.img> --target=<target.img> --compression="<compression-algorithm>"

    source.img -> Source partition image
    target.img -> Target partition image
    compression -> compression algorithm. Default set to lz4. Supported types are gz, lz4, zstd.

EXAMPLES

   $ create_snapshot --source=$SOURCE_BUILD/system.img --target=$TARGET_BUILD/system.img
   $ create_snapshot --source=$SOURCE_BUILD/product.img --target=$TARGET_BUILD/product.img --compression="zstd"

)";
485 
main(int argc,char * argv[])486 int main(int argc, char* argv[]) {
487     android::base::InitLogging(argv, &android::snapshot::CreateSnapshotLogger);
488     ::gflags::SetUsageMessage(kUsage);
489     ::gflags::ParseCommandLineFlags(&argc, &argv, true);
490 
491     if (FLAGS_source.empty() || FLAGS_target.empty()) {
492         LOG(INFO) << kUsage;
493         return 0;
494     }
495 
496     std::string fname = android::base::Basename(FLAGS_target.c_str());
497     auto parts = android::base::Split(fname, ".");
498     std::string snapshotfile = parts[0] + ".patch";
499     android::snapshot::CreateSnapshot snapshot(FLAGS_source, FLAGS_target, snapshotfile,
500                                                FLAGS_compression);
501 
502     if (!snapshot.CreateSnapshotPatch()) {
503         LOG(ERROR) << "Snapshot creation failed";
504         return -1;
505     }
506 
507     LOG(INFO) << "Snapshot patch: " << snapshotfile << " created successfully";
508     return 0;
509 }
510