Skip to main content

apache_avro/serde/ser_schema/record/
field_default.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use serde::{Serialize, Serializer, ser::Error};
19use serde_json::Value;
20
21use crate::{
22    Schema,
23    schema::{DecimalSchema, InnerDecimalSchema, SchemaKind, UnionSchema, UuidSchema},
24    serde::ser_schema::SERIALIZING_SCHEMA_DEFAULT,
25};
26
27pub struct SchemaAwareRecordFieldDefault<'v, 's> {
28    value: &'v Value,
29    schema: &'s Schema,
30}
31
32impl<'v, 's> SchemaAwareRecordFieldDefault<'v, 's> {
33    pub fn new(value: &'v Value, schema: &'s Schema) -> Self {
34        SchemaAwareRecordFieldDefault { value, schema }
35    }
36
37    fn serialize_as_newtype_variant<S: Serializer>(
38        &self,
39        serializer: S,
40        index: usize,
41        union: &'s UnionSchema,
42    ) -> Result<S::Ok, S::Error> {
43        let value = Self::new(self.value, &union.variants()[index]);
44        serializer.serialize_newtype_variant(
45            SERIALIZING_SCHEMA_DEFAULT,
46            index as u32,
47            SERIALIZING_SCHEMA_DEFAULT,
48            &value,
49        )
50    }
51
52    fn recursive_type_check(value: &Value, schema: &'s Schema) -> bool {
53        match (value, schema) {
54            (Value::Null, Schema::Null)
55            | (Value::Bool(_), Schema::Boolean)
56            | (
57                Value::String(_),
58                Schema::Bytes
59                | Schema::String
60                | Schema::Decimal(DecimalSchema {
61                    inner: InnerDecimalSchema::Bytes,
62                    ..
63                })
64                | Schema::BigDecimal
65                | Schema::Uuid(UuidSchema::Bytes | UuidSchema::String),
66            ) => true,
67            (Value::Number(n), Schema::Int | Schema::Date | Schema::TimeMillis) if n.is_i64() => {
68                let long = n.as_i64().unwrap();
69                i32::try_from(long).is_ok()
70            }
71            (
72                Value::Number(n),
73                Schema::Long
74                | Schema::TimeMicros
75                | Schema::TimestampMillis
76                | Schema::TimestampMicros
77                | Schema::TimestampNanos
78                | Schema::LocalTimestampMillis
79                | Schema::LocalTimestampMicros
80                | Schema::LocalTimestampNanos,
81            ) if n.is_i64() => true,
82            (Value::Number(n), Schema::Float | Schema::Double) if n.as_f64().is_some() => true,
83            (
84                Value::String(s),
85                Schema::Fixed(fixed)
86                | Schema::Decimal(DecimalSchema {
87                    inner: InnerDecimalSchema::Fixed(fixed),
88                    ..
89                })
90                | Schema::Uuid(UuidSchema::Fixed(fixed))
91                | Schema::Duration(fixed),
92            ) => s.len() == fixed.size,
93            (Value::String(s), Schema::Enum(enum_schema)) => enum_schema.symbols.contains(s),
94            (Value::Object(o), Schema::Record(record)) => record.fields.iter().all(|field| {
95                if let Some(value) = o.get(&field.name) {
96                    Self::recursive_type_check(value, &field.schema)
97                } else {
98                    field.default.is_some()
99                }
100            }),
101            (Value::Object(o), Schema::Map(map)) => o
102                .values()
103                .all(|value| Self::recursive_type_check(value, &map.types)),
104            (Value::Array(a), Schema::Array(array)) => a
105                .iter()
106                .all(|value| Self::recursive_type_check(value, &array.items)),
107            (_, Schema::Union(union)) => union
108                .variants()
109                .iter()
110                .any(|variant| Self::recursive_type_check(value, variant)),
111            _ => false,
112        }
113    }
114}
115
116impl<'v, 's> Serialize for SchemaAwareRecordFieldDefault<'v, 's> {
117    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
118    where
119        S: Serializer,
120    {
121        match (&self.value, self.schema) {
122            (Value::Null, Schema::Null) => serializer.serialize_unit(),
123            (Value::Bool(boolean), Schema::Boolean) => serializer.serialize_bool(*boolean),
124            (Value::Number(n), Schema::Int | Schema::Date | Schema::TimeMillis) if n.is_i64() => {
125                let long = n.as_i64().unwrap();
126                let int = i32::try_from(long).map_err(|_| {
127                    S::Error::custom(format!("Default {long} is too large for {:?}", self.schema))
128                })?;
129                serializer.serialize_i32(int)
130            }
131            (
132                Value::Number(n),
133                Schema::Long
134                | Schema::TimeMicros
135                | Schema::TimestampMillis
136                | Schema::TimestampMicros
137                | Schema::TimestampNanos
138                | Schema::LocalTimestampMillis
139                | Schema::LocalTimestampMicros
140                | Schema::LocalTimestampNanos,
141            ) if n.is_i64() => {
142                let long = n.as_i64().unwrap();
143                serializer.serialize_i64(long)
144            }
145            (Value::Number(n), Schema::Float) if n.as_f64().is_some() => {
146                serializer.serialize_f32(n.as_f64().unwrap() as f32)
147            }
148            (Value::Number(n), Schema::Double) if n.as_f64().is_some() => {
149                serializer.serialize_f64(n.as_f64().unwrap())
150            }
151            (
152                Value::String(s),
153                Schema::Bytes
154                | Schema::Fixed(_)
155                | Schema::Uuid(UuidSchema::Bytes | UuidSchema::Fixed(_))
156                | Schema::BigDecimal
157                | Schema::Decimal(_)
158                | Schema::Duration(_),
159            ) => serializer.serialize_bytes(s.as_bytes()),
160            (Value::String(s), Schema::String | Schema::Uuid(UuidSchema::String)) => {
161                serializer.serialize_str(s)
162            }
163            (Value::String(s), Schema::Enum(enum_schema)) => {
164                let Some((variant_index, _)) = enum_schema
165                    .symbols
166                    .iter()
167                    .enumerate()
168                    .find(|(_i, symbol)| *symbol == s)
169                else {
170                    return Err(S::Error::custom(format!(
171                        "Could not find `{s}` in enum: {enum_schema:?}"
172                    )));
173                };
174
175                serializer.serialize_unit_variant(
176                    SERIALIZING_SCHEMA_DEFAULT,
177                    variant_index as u32,
178                    SERIALIZING_SCHEMA_DEFAULT,
179                )
180            }
181            // This abuses the support for flattened fields, which are also serialized as a map.
182            (Value::Object(o), Schema::Record(record)) => {
183                serializer.collect_map(record.fields.iter().filter_map(|field| {
184                    o.get(&field.name)
185                        .map(|value| (&field.name, Self::new(value, &field.schema)))
186                }))
187            }
188            (Value::Object(o), Schema::Map(map)) => {
189                serializer.collect_map(o.iter().map(|(k, v)| (k, Self::new(v, &map.types))))
190            }
191            (Value::Array(a), Schema::Array(array)) => {
192                serializer.collect_seq(a.iter().map(|v| Self::new(v, &array.items)))
193            }
194            (_, Schema::Union(union)) => {
195                if union.variants().len() == 2
196                    && let Some(null_index) = union.index_of_schema_kind(SchemaKind::Null)
197                {
198                    // Fast path for options
199                    if self.value == &Value::Null {
200                        serializer.serialize_none()
201                    } else {
202                        let some_index = (null_index + 1) & 1;
203                        let value = Self::new(self.value, &union.variants()[some_index]);
204                        serializer.serialize_some(&value)
205                    }
206                } else {
207                    // Find the first variant that can match this value
208                    for (index, variant) in union.variants().iter().enumerate() {
209                        match (self.value, variant) {
210                            (Value::Null, Schema::Null) => {
211                                let index = index as u32;
212                                return serializer.serialize_unit_variant(
213                                    SERIALIZING_SCHEMA_DEFAULT,
214                                    index,
215                                    SERIALIZING_SCHEMA_DEFAULT,
216                                );
217                            }
218                            _ if Self::recursive_type_check(self.value, variant) => {
219                                return self.serialize_as_newtype_variant(serializer, index, union);
220                            }
221                            _ => {}
222                        }
223                    }
224                    Err(S::Error::custom(format!(
225                        "Could not match default to any variant of {:?}, default: {:?}",
226                        self.schema, self.value
227                    )))
228                }
229            }
230            _ => Err(S::Error::custom(format!(
231                "Unexpected default for {:?}, default: {:?}",
232                self.schema, self.value
233            ))),
234        }
235    }
236}