xref: /aosp_15_r20/external/cronet/third_party/rust/serde_json_lenient/v0_2/wrapper/visitor.rs (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2022 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use crate::{ContextPointer, Functions};
6 use serde::de::{DeserializeSeed, Deserializer, Error, MapAccess, SeqAccess, Visitor};
7 use std::convert::TryFrom;
8 use std::fmt;
9 use std::pin::Pin;
10 
11 /// Watches to ensure recursion does not go too deep during deserialization.
12 struct RecursionDepthCheck(usize);
13 
14 impl RecursionDepthCheck {
15     /// Recurse a level and return an error if we've recursed too far.
recurse<E: Error>(&self) -> Result<RecursionDepthCheck, E>16     fn recurse<E: Error>(&self) -> Result<RecursionDepthCheck, E> {
17         match self.0.checked_sub(1) {
18             Some(recursion_limit) => Ok(RecursionDepthCheck(recursion_limit)),
19             None => Err(Error::custom("recursion limit exceeded")),
20         }
21     }
22 }
23 
24 /// What type of aggregate JSON type is being deserialized.
25 pub enum DeserializationTarget<'c> {
26     /// Deserialize by appending to a list.
27     List { ctx: Pin<&'c mut ContextPointer> },
28     /// Deserialize by setting a dictionary key.
29     Dict { ctx: Pin<&'c mut ContextPointer>, key: String },
30 }
31 
32 /// A deserializer and visitor type that is used to visit each value in the JSON
33 /// input when it is deserialized.
34 ///
35 /// Normally serde deserialization instantiates a new object, but this visitor
36 /// is designed to call back into C++ for creating the deserialized objects. To
37 /// achieve this we use a feature of serde called "stateful deserialization" (https://docs.serde.rs/serde/de/trait.DeserializeSeed.html).
38 pub struct ValueVisitor<'c> {
39     fns: &'static Functions,
40     aggregate: DeserializationTarget<'c>,
41     recursion_depth_check: RecursionDepthCheck,
42 }
43 
44 impl<'c> ValueVisitor<'c> {
new( fns: &'static Functions, target: DeserializationTarget<'c>, max_depth: usize, ) -> Self45     pub fn new(
46         fns: &'static Functions,
47         target: DeserializationTarget<'c>,
48         max_depth: usize,
49     ) -> Self {
50         Self {
51             fns,
52             aggregate: target,
53             // The `max_depth` includes the top level of the JSON input, which is where parsing
54             // starts. We subtract 1 to count the top level now.
55             recursion_depth_check: RecursionDepthCheck(max_depth - 1),
56         }
57     }
58 }
59 
60 impl<'de, 'c> Visitor<'de> for ValueVisitor<'c> {
61     // We call out to C++ to construct the deserialized type, so no output from the
62     // visitor.
63     type Value = ();
64 
expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result65     fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
66         formatter.write_str("any valid JSON")
67     }
68 
visit_i32<E: serde::de::Error>(self, value: i32) -> Result<Self::Value, E>69     fn visit_i32<E: serde::de::Error>(self, value: i32) -> Result<Self::Value, E> {
70         match self.aggregate {
71             DeserializationTarget::List { ctx } => self.fns.list_append_i32(ctx, value),
72             DeserializationTarget::Dict { ctx, key } => self.fns.dict_set_i32(ctx, &key, value),
73         };
74         Ok(())
75     }
76 
visit_i8<E: serde::de::Error>(self, value: i8) -> Result<Self::Value, E>77     fn visit_i8<E: serde::de::Error>(self, value: i8) -> Result<Self::Value, E> {
78         self.visit_i32(value as i32)
79     }
80 
visit_bool<E: serde::de::Error>(self, value: bool) -> Result<Self::Value, E>81     fn visit_bool<E: serde::de::Error>(self, value: bool) -> Result<Self::Value, E> {
82         match self.aggregate {
83             DeserializationTarget::List { ctx } => self.fns.list_append_bool(ctx, value),
84             DeserializationTarget::Dict { ctx, key } => self.fns.dict_set_bool(ctx, &key, value),
85         };
86         Ok(())
87     }
88 
visit_i64<E: serde::de::Error>(self, value: i64) -> Result<Self::Value, E>89     fn visit_i64<E: serde::de::Error>(self, value: i64) -> Result<Self::Value, E> {
90         // Integer values that are > 32 bits large are returned as doubles instead. See
91         // JSONReaderTest.LargerIntIsLossy for a related test.
92         match i32::try_from(value) {
93             Ok(value) => self.visit_i32(value),
94             Err(_) => self.visit_f64(value as f64),
95         }
96     }
97 
visit_u64<E: serde::de::Error>(self, value: u64) -> Result<Self::Value, E>98     fn visit_u64<E: serde::de::Error>(self, value: u64) -> Result<Self::Value, E> {
99         // See visit_i64 comment.
100         match i32::try_from(value) {
101             Ok(value) => self.visit_i32(value),
102             Err(_) => self.visit_f64(value as f64),
103         }
104     }
105 
visit_f64<E: serde::de::Error>(self, value: f64) -> Result<Self::Value, E>106     fn visit_f64<E: serde::de::Error>(self, value: f64) -> Result<Self::Value, E> {
107         match self.aggregate {
108             DeserializationTarget::List { ctx } => self.fns.list_append_f64(ctx, value),
109             DeserializationTarget::Dict { ctx, key } => self.fns.dict_set_f64(ctx, &key, value),
110         };
111         Ok(())
112     }
113 
visit_str<E: serde::de::Error>(self, value: &str) -> Result<Self::Value, E>114     fn visit_str<E: serde::de::Error>(self, value: &str) -> Result<Self::Value, E> {
115         match self.aggregate {
116             DeserializationTarget::List { ctx } => self.fns.list_append_str(ctx, value),
117             DeserializationTarget::Dict { ctx, key } => self.fns.dict_set_str(ctx, &key, value),
118         };
119         Ok(())
120     }
121 
visit_borrowed_str<E: serde::de::Error>(self, value: &'de str) -> Result<Self::Value, E>122     fn visit_borrowed_str<E: serde::de::Error>(self, value: &'de str) -> Result<Self::Value, E> {
123         match self.aggregate {
124             DeserializationTarget::List { ctx } => self.fns.list_append_str(ctx, value),
125             DeserializationTarget::Dict { ctx, key } => self.fns.dict_set_str(ctx, &key, value),
126         };
127         Ok(())
128     }
129 
visit_string<E: serde::de::Error>(self, value: String) -> Result<Self::Value, E>130     fn visit_string<E: serde::de::Error>(self, value: String) -> Result<Self::Value, E> {
131         self.visit_str(&value)
132     }
133 
visit_none<E: serde::de::Error>(self) -> Result<Self::Value, E>134     fn visit_none<E: serde::de::Error>(self) -> Result<Self::Value, E> {
135         match self.aggregate {
136             DeserializationTarget::List { ctx } => self.fns.list_append_none(ctx),
137             DeserializationTarget::Dict { ctx, key } => self.fns.dict_set_none(ctx, &key),
138         };
139         Ok(())
140     }
141 
visit_unit<E: serde::de::Error>(self) -> Result<Self::Value, E>142     fn visit_unit<E: serde::de::Error>(self) -> Result<Self::Value, E> {
143         self.visit_none()
144     }
145 
visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error> where M: MapAccess<'de>,146     fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
147     where
148         M: MapAccess<'de>,
149     {
150         // TODO(danakj): base::Value::Dict doesn't expose a way to reserve space, so we
151         // don't bother using `access.size_hint()` here, unlike when creating a
152         // list.
153         let mut inner_ctx = match self.aggregate {
154             DeserializationTarget::List { ctx } => self.fns.list_append_dict(ctx),
155             DeserializationTarget::Dict { ctx, key } => self.fns.dict_set_dict(ctx, &key),
156         };
157         while let Some(key) = access.next_key::<String>()? {
158             access.next_value_seed(ValueVisitor {
159                 fns: self.fns,
160                 aggregate: DeserializationTarget::Dict { ctx: inner_ctx.as_mut(), key },
161                 recursion_depth_check: self.recursion_depth_check.recurse()?,
162             })?;
163         }
164         Ok(())
165     }
166 
visit_seq<S>(self, mut access: S) -> Result<Self::Value, S::Error> where S: SeqAccess<'de>,167     fn visit_seq<S>(self, mut access: S) -> Result<Self::Value, S::Error>
168     where
169         S: SeqAccess<'de>,
170     {
171         let mut inner_ctx = match self.aggregate {
172             DeserializationTarget::List { ctx } => {
173                 self.fns.list_append_list(ctx, access.size_hint().unwrap_or(0))
174             }
175             DeserializationTarget::Dict { ctx, key } => {
176                 self.fns.dict_set_list(ctx, &key, access.size_hint().unwrap_or(0))
177             }
178         };
179         while let Some(_) = access.next_element_seed(ValueVisitor {
180             fns: self.fns,
181             aggregate: DeserializationTarget::List { ctx: inner_ctx.as_mut() },
182             recursion_depth_check: self.recursion_depth_check.recurse()?,
183         })? {}
184         Ok(())
185     }
186 }
187 
188 impl<'de, 'c> DeserializeSeed<'de> for ValueVisitor<'c> {
189     // We call out to C++ to construct the deserialized type, so no output from
190     // here.
191     type Value = ();
192 
deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error> where D: Deserializer<'de>,193     fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
194     where
195         D: Deserializer<'de>,
196     {
197         deserializer.deserialize_any(self)
198     }
199 }
200