blob: 8cc3b6ed4d613d24c8cf0605eae9205706089dd3 [file] [log] [blame]
Austin Schuh272c6132020-11-14 16:37:52 -08001// Copyright 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::Error;
16use crate::{FlexBufferType, Reader, ReaderIterator};
17use serde::de::{
18 DeserializeSeed, Deserializer, EnumAccess, IntoDeserializer, MapAccess, SeqAccess,
19 VariantAccess, Visitor,
20};
21
22/// Errors that may happen when deserializing a flexbuffer with serde.
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub enum DeserializationError {
25 Reader(Error),
26 Serde(String),
27}
28
29impl std::error::Error for DeserializationError {}
30impl std::fmt::Display for DeserializationError {
31 fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
32 match self {
33 Self::Reader(r) => write!(f, "Flexbuffer Read Error: {:?}", r),
34 Self::Serde(s) => write!(f, "Serde Error: {}", s),
35 }
36 }
37}
38impl serde::de::Error for DeserializationError {
39 fn custom<T>(msg: T) -> Self
40 where
41 T: std::fmt::Display,
42 {
43 Self::Serde(format!("{}", msg))
44 }
45}
46impl std::convert::From<super::Error> for DeserializationError {
47 fn from(e: super::Error) -> Self {
48 Self::Reader(e)
49 }
50}
51
52impl<'de> SeqAccess<'de> for ReaderIterator<'de> {
53 type Error = DeserializationError;
54 fn next_element_seed<T>(
55 &mut self,
56 seed: T,
57 ) -> Result<Option<<T as DeserializeSeed<'de>>::Value>, Self::Error>
58 where
59 T: DeserializeSeed<'de>,
60 {
61 if let Some(elem) = self.next() {
62 seed.deserialize(elem).map(Some)
63 } else {
64 Ok(None)
65 }
66 }
67 fn size_hint(&self) -> Option<usize> {
68 Some(self.len())
69 }
70}
71
72struct EnumReader<'de> {
73 variant: &'de str,
74 value: Option<Reader<'de>>,
75}
76
77impl<'de> EnumAccess<'de> for EnumReader<'de> {
78 type Error = DeserializationError;
79 type Variant = Reader<'de>;
80 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
81 where
82 V: DeserializeSeed<'de>,
83 {
84 seed.deserialize(self.variant.into_deserializer())
85 .map(|v| (v, self.value.unwrap_or_default()))
86 }
87}
88
89struct MapAccessor<'de> {
90 keys: ReaderIterator<'de>,
91 vals: ReaderIterator<'de>,
92}
93impl<'de> MapAccess<'de> for MapAccessor<'de> {
94 type Error = DeserializationError;
95
96 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
97 where
98 K: DeserializeSeed<'de>,
99 {
100 if let Some(k) = self.keys.next() {
101 seed.deserialize(k).map(Some)
102 } else {
103 Ok(None)
104 }
105 }
106 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
107 where
108 V: DeserializeSeed<'de>,
109 {
110 let val = self.vals.next().ok_or(Error::IndexOutOfBounds)?;
111 seed.deserialize(val)
112 }
113}
114
115impl<'de> VariantAccess<'de> for Reader<'de> {
116 type Error = DeserializationError;
117 fn unit_variant(self) -> Result<(), Self::Error> {
118 Ok(())
119 }
120 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
121 where
122 T: DeserializeSeed<'de>,
123 {
124 seed.deserialize(self)
125 }
126 // Tuple variants have an internally tagged representation. They are vectors where Index 0 is
127 // the discriminant and index N is field N-1.
128 fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
129 where
130 V: Visitor<'de>,
131 {
132 visitor.visit_seq(self.as_vector().iter())
133 }
134 // Struct variants have an internally tagged representation. They are vectors where Index 0 is
135 // the discriminant and index N is field N-1.
136 fn struct_variant<V>(
137 self,
138 _fields: &'static [&'static str],
139 visitor: V,
140 ) -> Result<V::Value, Self::Error>
141 where
142 V: Visitor<'de>,
143 {
144 let m = self.get_map()?;
145 visitor.visit_map(MapAccessor {
146 keys: m.keys_vector().iter(),
147 vals: m.iter_values(),
148 })
149 }
150}
151
152impl<'de> Deserializer<'de> for crate::Reader<'de> {
153 type Error = DeserializationError;
154 fn is_human_readable(&self) -> bool {
155 cfg!(deserialize_human_readable)
156 }
157
158 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
159 where
160 V: Visitor<'de>,
161 {
162 use crate::BitWidth::*;
163 use crate::FlexBufferType::*;
164 match (self.flexbuffer_type(), self.bitwidth()) {
165 (Bool, _) => visitor.visit_bool(self.as_bool()),
166 (UInt, W8) => visitor.visit_u8(self.as_u8()),
167 (UInt, W16) => visitor.visit_u16(self.as_u16()),
168 (UInt, W32) => visitor.visit_u32(self.as_u32()),
169 (UInt, W64) => visitor.visit_u64(self.as_u64()),
170 (Int, W8) => visitor.visit_i8(self.as_i8()),
171 (Int, W16) => visitor.visit_i16(self.as_i16()),
172 (Int, W32) => visitor.visit_i32(self.as_i32()),
173 (Int, W64) => visitor.visit_i64(self.as_i64()),
174 (Float, W32) => visitor.visit_f32(self.as_f32()),
175 (Float, W64) => visitor.visit_f64(self.as_f64()),
176 (Float, _) => Err(Error::InvalidPackedType.into()), // f8 and f16 are not supported.
177 (Null, _) => visitor.visit_unit(),
178 (String, _) | (Key, _) => visitor.visit_borrowed_str(self.as_str()),
179 (Blob, _) => visitor.visit_borrowed_bytes(self.get_blob()?.0),
180 (Map, _) => {
181 let m = self.get_map()?;
182 visitor.visit_map(MapAccessor {
183 keys: m.keys_vector().iter(),
184 vals: m.iter_values(),
185 })
186 }
187 (ty, _) if ty.is_vector() => visitor.visit_seq(self.as_vector().iter()),
188 (ty, bw) => unreachable!("TODO deserialize_any {:?} {:?}.", ty, bw),
189 }
190 }
191 serde::forward_to_deserialize_any! {
192 bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 str unit unit_struct bytes
193 ignored_any map identifier struct tuple tuple_struct seq string
194 }
195 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
196 where
197 V: Visitor<'de>,
198 {
199 visitor.visit_char(self.as_u8() as char)
200 }
201 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
202 where
203 V: Visitor<'de>,
204 {
205 visitor.visit_byte_buf(self.get_blob()?.0.to_vec())
206 }
207 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
208 where
209 V: Visitor<'de>,
210 {
211 if self.flexbuffer_type() == FlexBufferType::Null {
212 visitor.visit_none()
213 } else {
214 visitor.visit_some(self)
215 }
216 }
217 fn deserialize_newtype_struct<V>(
218 self,
219 _name: &'static str,
220 visitor: V,
221 ) -> Result<V::Value, Self::Error>
222 where
223 V: Visitor<'de>,
224 {
225 visitor.visit_newtype_struct(self)
226 }
227 fn deserialize_enum<V>(
228 self,
229 _name: &'static str,
230 _variants: &'static [&'static str],
231 visitor: V,
232 ) -> Result<V::Value, Self::Error>
233 where
234 V: Visitor<'de>,
235 {
236 let (variant, value) = match self.fxb_type {
237 FlexBufferType::String => (self.as_str(), None),
238 FlexBufferType::Map => {
239 let m = self.get_map()?;
240 let variant = m.keys_vector().idx(0).get_key()?;
241 let value = Some(m.idx(0));
242 (variant, value)
243 }
244 _ => {
245 return Err(Error::UnexpectedFlexbufferType {
246 expected: FlexBufferType::Map,
247 actual: self.fxb_type,
248 }
249 .into());
250 }
251 };
252 visitor.visit_enum(EnumReader { variant, value })
253 }
254}