1 // Copyright 2022 The Chromium Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 use crate::{ContextPointer, Functions}; 6 use serde::de::{DeserializeSeed, Deserializer, Error, MapAccess, SeqAccess, Visitor}; 7 use std::convert::TryFrom; 8 use std::fmt; 9 use std::pin::Pin; 10 11 /// Watches to ensure recursion does not go too deep during deserialization. 12 struct RecursionDepthCheck(usize); 13 14 impl RecursionDepthCheck { 15 /// Recurse a level and return an error if we've recursed too far. recurse<E: Error>(&self) -> Result<RecursionDepthCheck, E>16 fn recurse<E: Error>(&self) -> Result<RecursionDepthCheck, E> { 17 match self.0.checked_sub(1) { 18 Some(recursion_limit) => Ok(RecursionDepthCheck(recursion_limit)), 19 None => Err(Error::custom("recursion limit exceeded")), 20 } 21 } 22 } 23 24 /// What type of aggregate JSON type is being deserialized. 25 pub enum DeserializationTarget<'c> { 26 /// Deserialize by appending to a list. 27 List { ctx: Pin<&'c mut ContextPointer> }, 28 /// Deserialize by setting a dictionary key. 29 Dict { ctx: Pin<&'c mut ContextPointer>, key: String }, 30 } 31 32 /// A deserializer and visitor type that is used to visit each value in the JSON 33 /// input when it is deserialized. 34 /// 35 /// Normally serde deserialization instantiates a new object, but this visitor 36 /// is designed to call back into C++ for creating the deserialized objects. To 37 /// achieve this we use a feature of serde called "stateful deserialization" (https://docs.serde.rs/serde/de/trait.DeserializeSeed.html). 38 pub struct ValueVisitor<'c> { 39 fns: &'static Functions, 40 aggregate: DeserializationTarget<'c>, 41 recursion_depth_check: RecursionDepthCheck, 42 } 43 44 impl<'c> ValueVisitor<'c> { new( fns: &'static Functions, target: DeserializationTarget<'c>, max_depth: usize, ) -> Self45 pub fn new( 46 fns: &'static Functions, 47 target: DeserializationTarget<'c>, 48 max_depth: usize, 49 ) -> Self { 50 Self { 51 fns, 52 aggregate: target, 53 // The `max_depth` includes the top level of the JSON input, which is where parsing 54 // starts. We subtract 1 to count the top level now. 55 recursion_depth_check: RecursionDepthCheck(max_depth - 1), 56 } 57 } 58 } 59 60 impl<'de, 'c> Visitor<'de> for ValueVisitor<'c> { 61 // We call out to C++ to construct the deserialized type, so no output from the 62 // visitor. 63 type Value = (); 64 expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result65 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { 66 formatter.write_str("any valid JSON") 67 } 68 visit_i32<E: serde::de::Error>(self, value: i32) -> Result<Self::Value, E>69 fn visit_i32<E: serde::de::Error>(self, value: i32) -> Result<Self::Value, E> { 70 match self.aggregate { 71 DeserializationTarget::List { ctx } => self.fns.list_append_i32(ctx, value), 72 DeserializationTarget::Dict { ctx, key } => self.fns.dict_set_i32(ctx, &key, value), 73 }; 74 Ok(()) 75 } 76 visit_i8<E: serde::de::Error>(self, value: i8) -> Result<Self::Value, E>77 fn visit_i8<E: serde::de::Error>(self, value: i8) -> Result<Self::Value, E> { 78 self.visit_i32(value as i32) 79 } 80 visit_bool<E: serde::de::Error>(self, value: bool) -> Result<Self::Value, E>81 fn visit_bool<E: serde::de::Error>(self, value: bool) -> Result<Self::Value, E> { 82 match self.aggregate { 83 DeserializationTarget::List { ctx } => self.fns.list_append_bool(ctx, value), 84 DeserializationTarget::Dict { ctx, key } => self.fns.dict_set_bool(ctx, &key, value), 85 }; 86 Ok(()) 87 } 88 visit_i64<E: serde::de::Error>(self, value: i64) -> Result<Self::Value, E>89 fn visit_i64<E: serde::de::Error>(self, value: i64) -> Result<Self::Value, E> { 90 // Integer values that are > 32 bits large are returned as doubles instead. See 91 // JSONReaderTest.LargerIntIsLossy for a related test. 92 match i32::try_from(value) { 93 Ok(value) => self.visit_i32(value), 94 Err(_) => self.visit_f64(value as f64), 95 } 96 } 97 visit_u64<E: serde::de::Error>(self, value: u64) -> Result<Self::Value, E>98 fn visit_u64<E: serde::de::Error>(self, value: u64) -> Result<Self::Value, E> { 99 // See visit_i64 comment. 100 match i32::try_from(value) { 101 Ok(value) => self.visit_i32(value), 102 Err(_) => self.visit_f64(value as f64), 103 } 104 } 105 visit_f64<E: serde::de::Error>(self, value: f64) -> Result<Self::Value, E>106 fn visit_f64<E: serde::de::Error>(self, value: f64) -> Result<Self::Value, E> { 107 match self.aggregate { 108 DeserializationTarget::List { ctx } => self.fns.list_append_f64(ctx, value), 109 DeserializationTarget::Dict { ctx, key } => self.fns.dict_set_f64(ctx, &key, value), 110 }; 111 Ok(()) 112 } 113 visit_str<E: serde::de::Error>(self, value: &str) -> Result<Self::Value, E>114 fn visit_str<E: serde::de::Error>(self, value: &str) -> Result<Self::Value, E> { 115 match self.aggregate { 116 DeserializationTarget::List { ctx } => self.fns.list_append_str(ctx, value), 117 DeserializationTarget::Dict { ctx, key } => self.fns.dict_set_str(ctx, &key, value), 118 }; 119 Ok(()) 120 } 121 visit_borrowed_str<E: serde::de::Error>(self, value: &'de str) -> Result<Self::Value, E>122 fn visit_borrowed_str<E: serde::de::Error>(self, value: &'de str) -> Result<Self::Value, E> { 123 match self.aggregate { 124 DeserializationTarget::List { ctx } => self.fns.list_append_str(ctx, value), 125 DeserializationTarget::Dict { ctx, key } => self.fns.dict_set_str(ctx, &key, value), 126 }; 127 Ok(()) 128 } 129 visit_string<E: serde::de::Error>(self, value: String) -> Result<Self::Value, E>130 fn visit_string<E: serde::de::Error>(self, value: String) -> Result<Self::Value, E> { 131 self.visit_str(&value) 132 } 133 visit_none<E: serde::de::Error>(self) -> Result<Self::Value, E>134 fn visit_none<E: serde::de::Error>(self) -> Result<Self::Value, E> { 135 match self.aggregate { 136 DeserializationTarget::List { ctx } => self.fns.list_append_none(ctx), 137 DeserializationTarget::Dict { ctx, key } => self.fns.dict_set_none(ctx, &key), 138 }; 139 Ok(()) 140 } 141 visit_unit<E: serde::de::Error>(self) -> Result<Self::Value, E>142 fn visit_unit<E: serde::de::Error>(self) -> Result<Self::Value, E> { 143 self.visit_none() 144 } 145 visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error> where M: MapAccess<'de>,146 fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error> 147 where 148 M: MapAccess<'de>, 149 { 150 // TODO(danakj): base::Value::Dict doesn't expose a way to reserve space, so we 151 // don't bother using `access.size_hint()` here, unlike when creating a 152 // list. 153 let mut inner_ctx = match self.aggregate { 154 DeserializationTarget::List { ctx } => self.fns.list_append_dict(ctx), 155 DeserializationTarget::Dict { ctx, key } => self.fns.dict_set_dict(ctx, &key), 156 }; 157 while let Some(key) = access.next_key::<String>()? { 158 access.next_value_seed(ValueVisitor { 159 fns: self.fns, 160 aggregate: DeserializationTarget::Dict { ctx: inner_ctx.as_mut(), key }, 161 recursion_depth_check: self.recursion_depth_check.recurse()?, 162 })?; 163 } 164 Ok(()) 165 } 166 visit_seq<S>(self, mut access: S) -> Result<Self::Value, S::Error> where S: SeqAccess<'de>,167 fn visit_seq<S>(self, mut access: S) -> Result<Self::Value, S::Error> 168 where 169 S: SeqAccess<'de>, 170 { 171 let mut inner_ctx = match self.aggregate { 172 DeserializationTarget::List { ctx } => { 173 self.fns.list_append_list(ctx, access.size_hint().unwrap_or(0)) 174 } 175 DeserializationTarget::Dict { ctx, key } => { 176 self.fns.dict_set_list(ctx, &key, access.size_hint().unwrap_or(0)) 177 } 178 }; 179 while let Some(_) = access.next_element_seed(ValueVisitor { 180 fns: self.fns, 181 aggregate: DeserializationTarget::List { ctx: inner_ctx.as_mut() }, 182 recursion_depth_check: self.recursion_depth_check.recurse()?, 183 })? {} 184 Ok(()) 185 } 186 } 187 188 impl<'de, 'c> DeserializeSeed<'de> for ValueVisitor<'c> { 189 // We call out to C++ to construct the deserialized type, so no output from 190 // here. 191 type Value = (); 192 deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error> where D: Deserializer<'de>,193 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error> 194 where 195 D: Deserializer<'de>, 196 { 197 deserializer.deserialize_any(self) 198 } 199 } 200