1 // SPDX-License-Identifier: LGPL-2.1
2 /*
3 * Copyright (C) 2022, Sebastian Andrzej Siewior <[email protected]>
4 *
5 */
#include <errno.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <zstd.h>

#include "trace-cmd-private.h"
11
/* Algorithm name registered with trace-cmd and matched by zstd_is_supported(). */
#define __ZSTD_NAME "zstd"
/*
 * Value stored in proto.weight by tracecmd_zstd_init(). The "WEIGTH"
 * spelling is a typo kept for compatibility with existing references.
 * Presumably the weight ranks this algorithm against other compression
 * protocols — TODO confirm in trace-cmd-private.h.
 * NOTE(review): identifiers starting with "__" are reserved for the
 * implementation; consider renaming both macros in a follow-up.
 */
#define __ZSTD_WEIGTH 5
14
/*
 * Per-instance state created by new_zstd_context() and released by
 * free_zstd_context(): one zstd compression and one decompression context.
 */
struct zstd_context {
	ZSTD_CCtx *ctx_c;	/* compression context, used by zstd_compress() */
	ZSTD_DCtx *ctx_d;	/* decompression context, used by zstd_decompress() */
};
19
zstd_compress(void * ctx,const void * in,int in_bytes,void * out,int out_bytes)20 static int zstd_compress(void *ctx, const void *in, int in_bytes, void *out, int out_bytes)
21 {
22 struct zstd_context *context = ctx;
23 size_t ret;
24
25 if (!ctx)
26 return -1;
27
28 ret = ZSTD_compress2(context->ctx_c, out, out_bytes, in, in_bytes);
29 if (ZSTD_isError(ret))
30 return -1;
31
32 return ret;
33 }
34
zstd_decompress(void * ctx,const void * in,int in_bytes,void * out,int out_bytes)35 static int zstd_decompress(void *ctx, const void *in, int in_bytes, void *out, int out_bytes)
36 {
37 struct zstd_context *context = ctx;
38 size_t ret;
39
40 if (!ctx)
41 return -1;
42
43 ret = ZSTD_decompressDCtx(context->ctx_d, out, out_bytes, in, in_bytes);
44 if (ZSTD_isError(ret)) {
45 errno = -EINVAL;
46 return -1;
47 }
48
49 return ret;
50 }
51
/*
 * zstd_compress_bound - worst-case output size for compressing @in_bytes.
 * @ctx: unused; the bound does not depend on any context state
 * @in_bytes: size of the input that will be compressed
 *
 * Returns ZSTD_compressBound(@in_bytes), converted to unsigned int.
 *
 * NOTE(review): ZSTD_compressBound() yields a size_t. For inputs close to
 * UINT_MAX the conversion to unsigned int truncates, and on error zstd
 * returns an error code rather than a size. Confirm callers never ask for
 * such large chunks.
 */
static unsigned int zstd_compress_bound(void *ctx, unsigned int in_bytes)
{
	size_t worst_case;

	(void)ctx;
	worst_case = ZSTD_compressBound(in_bytes);

	return worst_case;
}
56
zstd_is_supported(const char * name,const char * version)57 static bool zstd_is_supported(const char *name, const char *version)
58 {
59 if (!name)
60 return false;
61 if (strcmp(name, __ZSTD_NAME))
62 return false;
63
64 return true;
65 }
66
new_zstd_context(void)67 static void *new_zstd_context(void)
68 {
69 struct zstd_context *context;
70 size_t r;
71
72 context = calloc(1, sizeof(*context));
73 if (!context)
74 return NULL;
75
76 context->ctx_c = ZSTD_createCCtx();
77 context->ctx_d = ZSTD_createDCtx();
78 if (!context->ctx_c || !context->ctx_d)
79 goto err;
80
81 r = ZSTD_CCtx_setParameter(context->ctx_c, ZSTD_c_contentSizeFlag, 0);
82 if (ZSTD_isError(r))
83 goto err;
84
85 return context;
86 err:
87 ZSTD_freeCCtx(context->ctx_c);
88 ZSTD_freeDCtx(context->ctx_d);
89 free(context);
90 return NULL;
91 }
free_zstd_context(void * ctx)92 static void free_zstd_context(void *ctx)
93 {
94 struct zstd_context *context = ctx;
95
96 if (!ctx)
97 return;
98
99 ZSTD_freeCCtx(context->ctx_c);
100 ZSTD_freeDCtx(context->ctx_d);
101 free(context);
102 }
103
tracecmd_zstd_init(void)104 int tracecmd_zstd_init(void)
105 {
106 struct tracecmd_compression_proto proto;
107
108 memset(&proto, 0, sizeof(proto));
109 proto.name = __ZSTD_NAME;
110 proto.version = ZSTD_versionString();
111 proto.weight = __ZSTD_WEIGTH;
112 proto.compress = zstd_compress;
113 proto.uncompress = zstd_decompress;
114 proto.is_supported = zstd_is_supported;
115 proto.compress_size = zstd_compress_bound;
116 proto.new_context = new_zstd_context;
117 proto.free_context = free_zstd_context;
118
119 return tracecmd_compress_proto_register(&proto);
120 }
121