1use std::{borrow::Borrow, io::Write};
19
20use serde::{Serialize, Serializer};
21
22use super::{Config, MapOrRecordSerializer, SchemaAwareSerializer};
23use crate::{
24 Error, Schema,
25 error::Details,
26 schema::{FixedSchema, SchemaKind, UnionSchema},
27 serde::{
28 ser_schema::{
29 block::BlockSerializer,
30 record::RecordSerializer,
31 tuple::{ManyTupleSerializer, TupleSerializer},
32 },
33 with::{BytesType, SER_BYTES_TYPE},
34 },
35 util::{zig_i32, zig_i64},
36};
37
38pub struct UnionSerializer<'s, 'w, W: Write, S: Borrow<Schema>> {
43 writer: &'w mut W,
44 union: &'s UnionSchema,
45 config: Config<'s, S>,
46}
47
48impl<'s, 'w, W: Write, S: Borrow<Schema>> UnionSerializer<'s, 'w, W, S> {
49 pub fn new(writer: &'w mut W, union: &'s UnionSchema, config: Config<'s, S>) -> Self {
50 UnionSerializer {
51 writer,
52 union,
53 config,
54 }
55 }
56
57 fn error(&self, ty: &'static str, error: impl Into<String>) -> Error {
58 Error::new(Details::SerializeValueWithSchema {
59 value_type: ty,
60 value: error.into(),
61 schema: Schema::Union(self.union.clone()),
62 })
63 }
64
65 pub(super) fn checked_write_int(
70 &mut self,
71 original_ty: &'static str,
72 v: i32,
73 ) -> Result<usize, Error> {
74 if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Int) {
75 let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
76 bytes_written += zig_i32(v, &mut *self.writer)?;
77 Ok(bytes_written)
78 } else {
79 Err(self.error(
80 original_ty,
81 "Expected Schema::Int | Schema::Date | Schema::TimeMillis in variants",
82 ))
83 }
84 }
85
86 pub(super) fn checked_write_long(
91 &mut self,
92 original_ty: &'static str,
93 v: i64,
94 ) -> Result<usize, Error> {
95 if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Long) {
96 let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
97 bytes_written += zig_i64(v, &mut *self.writer)?;
98 Ok(bytes_written)
99 } else {
100 Err(self.error(original_ty, "Expected Schema::Long | Schema::TimeMicros | Schema::{,Local}Timestamp{Millis,Micros,Nanos} in variants"))
101 }
102 }
103
104 fn write_bytes_with_len(&mut self, bytes: &[u8]) -> Result<usize, Error> {
108 let mut bytes_written = 0;
109 bytes_written += zig_i64(bytes.len() as i64, &mut *self.writer)?;
110 bytes_written += self.write_bytes(bytes)?;
111 Ok(bytes_written)
112 }
113
114 fn write_bytes(&mut self, bytes: &[u8]) -> Result<usize, Error> {
118 self.writer.write_all(bytes).map_err(Details::WriteBytes)?;
119 Ok(bytes.len())
120 }
121
122 fn write_array<const N: usize>(&mut self, bytes: [u8; N]) -> Result<usize, Error> {
126 self.write_bytes(&bytes)?;
127 Ok(N)
128 }
129}
130
131impl<'s, 'w, W: Write, S: Borrow<Schema>> Serializer for UnionSerializer<'s, 'w, W, S> {
132 type Ok = usize;
133 type Error = Error;
134 type SerializeSeq = BlockSerializer<'s, 'w, W, S>;
135 type SerializeTuple = TupleSerializer<'s, 'w, W, S>;
136 type SerializeTupleStruct = ManyTupleSerializer<'s, 'w, W, S>;
137 type SerializeTupleVariant = ManyTupleSerializer<'s, 'w, W, S>;
138 type SerializeMap = MapOrRecordSerializer<'s, 'w, W, S>;
139 type SerializeStruct = RecordSerializer<'s, 'w, W, S>;
140 type SerializeStructVariant = RecordSerializer<'s, 'w, W, S>;
141
142 fn serialize_bool(mut self, v: bool) -> Result<Self::Ok, Self::Error> {
143 if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Boolean) {
144 let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
145 bytes_written += self.write_array([u8::from(v)])?;
146 Ok(bytes_written)
147 } else {
148 Err(self.error("bool", "Expected Schema::Boolean in variants"))
149 }
150 }
151
152 fn serialize_i8(mut self, v: i8) -> Result<Self::Ok, Self::Error> {
153 self.checked_write_int("i8", i32::from(v))
154 }
155
156 fn serialize_i16(mut self, v: i16) -> Result<Self::Ok, Self::Error> {
157 self.checked_write_int("i16", i32::from(v))
158 }
159
160 fn serialize_i32(mut self, v: i32) -> Result<Self::Ok, Self::Error> {
161 self.checked_write_int("i32", v)
162 }
163
164 fn serialize_i64(mut self, v: i64) -> Result<Self::Ok, Self::Error> {
165 self.checked_write_long("i64", v)
166 }
167
168 fn serialize_i128(mut self, v: i128) -> Result<Self::Ok, Self::Error> {
169 match self.union.find_named_schema("i128", self.config.names)? {
170 Some((index, Schema::Fixed(FixedSchema { size: 16, .. }))) => {
171 let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
172 bytes_written += self.write_array(v.to_le_bytes())?;
173 Ok(bytes_written)
174 }
175 _ => Err(self.error(
176 "i128",
177 r#"Expected Schema::Fixed(name: "i128", size: 16) in variants"#,
178 )),
179 }
180 }
181
182 fn serialize_u8(mut self, v: u8) -> Result<Self::Ok, Self::Error> {
183 self.checked_write_int("u8", i32::from(v))
184 }
185
186 fn serialize_u16(mut self, v: u16) -> Result<Self::Ok, Self::Error> {
187 self.checked_write_int("u16", i32::from(v))
188 }
189
190 fn serialize_u32(mut self, v: u32) -> Result<Self::Ok, Self::Error> {
191 self.checked_write_long("u32", i64::from(v))
192 }
193
194 fn serialize_u64(mut self, v: u64) -> Result<Self::Ok, Self::Error> {
195 match self.union.find_named_schema("u64", self.config.names)? {
196 Some((index, Schema::Fixed(FixedSchema { size: 8, .. }))) => {
197 let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
198 bytes_written += self.write_array(v.to_le_bytes())?;
199 Ok(bytes_written)
200 }
201 _ => Err(self.error(
202 "u64",
203 r#"Expected Schema::Fixed(name: "u64", size: 8) in variants"#,
204 )),
205 }
206 }
207
208 fn serialize_u128(mut self, v: u128) -> Result<Self::Ok, Self::Error> {
209 match self.union.find_named_schema("u128", self.config.names)? {
210 Some((index, Schema::Fixed(FixedSchema { size: 16, .. }))) => {
211 let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
212 bytes_written += self.write_array(v.to_le_bytes())?;
213 Ok(bytes_written)
214 }
215 _ => Err(self.error(
216 "u128",
217 r#"Expected Schema::Fixed(name: "u128", size: 16) in variants"#,
218 )),
219 }
220 }
221
222 fn serialize_f32(mut self, v: f32) -> Result<Self::Ok, Self::Error> {
223 if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Float) {
224 let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
225 bytes_written += self.write_array(v.to_le_bytes())?;
226 Ok(bytes_written)
227 } else {
228 Err(self.error("f32", "Expected Schema::Float in variants"))
229 }
230 }
231
232 fn serialize_f64(mut self, v: f64) -> Result<Self::Ok, Self::Error> {
233 if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Double) {
234 let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
235 bytes_written += self.write_array(v.to_le_bytes())?;
236 Ok(bytes_written)
237 } else {
238 Err(self.error("f64", "Expected Schema::Double in variants"))
239 }
240 }
241
242 fn serialize_char(mut self, v: char) -> Result<Self::Ok, Self::Error> {
243 if let Some(index) = self.union.index_of_schema_kind(SchemaKind::String) {
244 let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
245 bytes_written += self.write_bytes_with_len(v.to_string().as_bytes())?;
246 Ok(bytes_written)
247 } else {
248 Err(self.error("char", "Expected Schema::String in variants"))
249 }
250 }
251
252 fn serialize_str(mut self, v: &str) -> Result<Self::Ok, Self::Error> {
253 if let Some(index) = self.union.index_of_schema_kind(SchemaKind::String) {
254 let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
255 bytes_written += self.write_bytes_with_len(v.as_bytes())?;
256 Ok(bytes_written)
257 } else {
258 Err(self.error("str", "Expected Schema::String in variants"))
259 }
260 }
261
262 fn serialize_bytes(mut self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
263 let (index, with_len) = match SER_BYTES_TYPE.get() {
264 BytesType::Bytes => {
265 if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Bytes) {
266 (index, true)
267 } else {
268 return Err(self.error("bytes", "Expected Schema::Bytes in variants"));
269 }
270 }
271 BytesType::Fixed => {
272 if let Some((index, _)) = self
273 .union
274 .find_fixed_of_size_n(v.len(), self.config.names)?
275 {
276 (index, false)
277 } else {
278 return Err(self.error(
279 "bytes",
280 format!("Expected Schema::Fixed(size: {}) in variants", v.len()),
281 ));
282 }
283 }
284 BytesType::Unset => {
285 let bytes_index = self.union.index_of_schema_kind(SchemaKind::Bytes);
286 let fixed_index = self
287 .union
288 .find_fixed_of_size_n(v.len(), self.config.names)?;
289 match (bytes_index, fixed_index) {
291 (Some(bytes_index), Some((fixed_index, _))) => {
292 (bytes_index.min(fixed_index), bytes_index < fixed_index)
293 }
294 (Some(bytes_index), None) => (bytes_index, true),
295 (None, Some((fixed_index, _))) => (fixed_index, false),
296 (None, None) => {
297 return Err(self.error(
298 "bytes",
299 format!(
300 "Expected Schema::Bytes | Schema::Fixed(size: {}) in variants",
301 v.len()
302 ),
303 ));
304 }
305 }
306 }
307 };
308 let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
309 if with_len {
310 bytes_written += self.write_bytes_with_len(v)?;
311 } else {
312 bytes_written += self.write_bytes(v)?;
313 }
314 Ok(bytes_written)
315 }
316
317 fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
318 Err(self.error("none", "Nested unions are not supported"))
319 }
320
321 fn serialize_some<T>(self, _: &T) -> Result<Self::Ok, Self::Error>
322 where
323 T: ?Sized + Serialize,
324 {
325 Err(self.error("some", "Nested unions are not supported"))
326 }
327
328 fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
329 if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Null) {
330 zig_i32(index as i32, &mut *self.writer)
331 } else {
332 Err(self.error("unit", "Expected Schema::Null in variants"))
333 }
334 }
335
336 fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok, Self::Error> {
337 match self.union.find_named_schema(name, self.config.names)? {
338 Some((index, Schema::Record(record))) if record.fields.is_empty() => {
339 zig_i32(index as i32, &mut *self.writer)
340 }
341 _ => Err(self.error(
342 "unit struct",
343 format!("Expected Schema::Record(name: {name}, fields: []) in variants"),
344 )),
345 }
346 }
347
348 fn serialize_unit_variant(
349 self,
350 _: &'static str,
351 _: u32,
352 _: &'static str,
353 ) -> Result<Self::Ok, Self::Error> {
354 Err(self.error("unit variant", "Nested unions are not supported"))
355 }
356
357 fn serialize_newtype_struct<T>(
358 self,
359 name: &'static str,
360 value: &T,
361 ) -> Result<Self::Ok, Self::Error>
362 where
363 T: ?Sized + Serialize,
364 {
365 match self.union.find_named_schema(name, self.config.names)? {
366 Some((index, Schema::Record(record))) if record.fields.len() == 1 => {
367 let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
368 bytes_written += value.serialize(SchemaAwareSerializer::new(
369 self.writer,
370 &record.fields[0].schema,
371 self.config,
372 )?)?;
373 Ok(bytes_written)
374 }
375 _ => Err(self.error(
376 "newtype struct",
377 format!("Expected Schema::Record(name: {name}, fields: [_]) in variants"),
378 )),
379 }
380 }
381
382 fn serialize_newtype_variant<T>(
383 self,
384 _: &'static str,
385 _: u32,
386 _: &'static str,
387 _: &T,
388 ) -> Result<Self::Ok, Self::Error>
389 where
390 T: ?Sized + Serialize,
391 {
392 Err(self.error("newtype variant", "Nested unions are not supported"))
393 }
394
395 fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
396 if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Array)
397 && let Schema::Array(array) = &self.union.variants()[index]
398 {
399 let bytes_written = zig_i32(index as i32, &mut *self.writer)?;
400 BlockSerializer::array(self.writer, array, self.config, len, Some(bytes_written))
401 } else {
402 Err(self.error("array", "Expected Schema::Array in variants"))
403 }
404 }
405
406 fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple, Self::Error> {
407 if len == 0 {
408 if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Null) {
409 let bytes_written = zig_i32(index as i32, &mut *self.writer)?;
410 Ok(TupleSerializer::unit(Some(bytes_written)))
411 } else {
412 Err(self.error("tuple", "Expected Schema::Null in variants for 0-tuple"))
413 }
414 } else if len == 1 {
415 Ok(TupleSerializer::one_union(
416 self.writer,
417 self.union,
418 self.config,
419 ))
420 } else if let Some((index, record)) = self
421 .union
422 .find_record_with_n_fields(len, self.config.names)?
423 {
424 let bytes_written = zig_i32(index as i32, &mut *self.writer)?;
425 Ok(TupleSerializer::many(
426 self.writer,
427 record,
428 self.config,
429 Some(bytes_written),
430 ))
431 } else {
432 Err(self.error(
433 "tuple",
434 format!(
435 "Expected Schema::Record(fields.len() == {len}) in variants for {len}-tuple"
436 ),
437 ))
438 }
439 }
440
441 fn serialize_tuple_struct(
442 self,
443 name: &'static str,
444 len: usize,
445 ) -> Result<Self::SerializeTupleStruct, Self::Error> {
446 match self.union.find_named_schema(name, self.config.names)? {
447 Some((index, Schema::Record(record))) if record.fields.len() == len => {
448 let bytes_written = zig_i32(index as i32, &mut *self.writer)?;
449 Ok(ManyTupleSerializer::new(
450 self.writer,
451 record,
452 self.config,
453 Some(bytes_written),
454 ))
455 }
456 _ => Err(self.error(
457 "tuple struct",
458 format!("Expected Schema::Record(name: {name}, fields.len() == {len}) in variants"),
459 )),
460 }
461 }
462
463 fn serialize_tuple_variant(
464 self,
465 _: &'static str,
466 _: u32,
467 _: &'static str,
468 _: usize,
469 ) -> Result<Self::SerializeTupleVariant, Self::Error> {
470 Err(self.error("tuple variant", "Nested unions are not supported"))
471 }
472
473 fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
474 let map_index = self.union.index_of_schema_kind(SchemaKind::Map).map(|i| {
475 if let Schema::Map(map) = &self.union.variants()[i] {
476 (i, map)
477 } else {
478 unreachable!("SchemaKind is Map so Schema must also be a Map")
479 }
480 });
481 let record_index = if let Some(len) = len {
482 self.union
483 .find_record_with_n_fields(len, self.config.names)?
484 } else {
485 None
486 };
487 match (map_index, record_index) {
488 (Some((map_index, map)), Some((record_index, record))) => {
489 let bytes_written = zig_i32(map_index.min(record_index) as i32, &mut *self.writer)?;
490 if map_index < record_index {
491 MapOrRecordSerializer::map(
492 self.writer,
493 map,
494 self.config,
495 len,
496 Some(bytes_written),
497 )
498 } else {
499 Ok(MapOrRecordSerializer::record(
500 self.writer,
501 record,
502 self.config,
503 Some(bytes_written),
504 ))
505 }
506 }
507 (Some((map_index, map)), None) => {
508 let bytes_written = zig_i32(map_index as i32, &mut *self.writer)?;
509 MapOrRecordSerializer::map(self.writer, map, self.config, len, Some(bytes_written))
510 }
511 (None, Some((record_index, record))) => {
512 let bytes_written = zig_i32(record_index as i32, &mut *self.writer)?;
513 Ok(MapOrRecordSerializer::record(
514 self.writer,
515 record,
516 self.config,
517 Some(bytes_written),
518 ))
519 }
520 (None, None) => Err(self.error(
521 "map",
522 "Expected Schema::Map or Schema::Record for structs with flattened fields in variants",
523 )),
524 }
525 }
526
527 fn serialize_struct(
528 self,
529 name: &'static str,
530 _len: usize,
531 ) -> Result<Self::SerializeStruct, Self::Error> {
532 if let Some((index, Schema::Record(record))) =
533 self.union.find_named_schema(name, self.config.names)?
534 {
535 let bytes_written = zig_i32(index as i32, &mut *self.writer)?;
536 Ok(RecordSerializer::new(
537 self.writer,
538 record,
539 self.config,
540 Some(bytes_written),
541 ))
542 } else {
543 Err(self.error(
544 "struct",
545 format!("Expected Schema::Record(name: {name}) in variants"),
546 ))
547 }
548 }
549
550 fn serialize_struct_variant(
551 self,
552 _: &'static str,
553 _: u32,
554 _: &'static str,
555 _: usize,
556 ) -> Result<Self::SerializeStructVariant, Self::Error> {
557 Err(self.error("struct variant", "Nested unions are not supported"))
558 }
559
560 fn is_human_readable(&self) -> bool {
561 self.config.human_readable
562 }
563}