Skip to content

Commit 89a68e5

Browse files
committed
Support union serialisation/deserialisation.
1 parent c58cb5c commit 89a68e5

7 files changed

Lines changed: 320 additions & 1 deletion

File tree

Dasher.Tests/Dasher.Tests.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
<Compile Include="TestTypes.cs" />
6767
<Compile Include="TestUtils.cs" />
6868
<Compile Include="TypeSupportTests.cs" />
69+
<Compile Include="UnionProviderTests.cs" />
6970
<Compile Include="UnionTests.cs" />
7071
<Compile Include="UnpackerTests.cs" />
7172
<Compile Include="DeserialiserTests.cs" />

Dasher.Tests/DeserialiserTests.cs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,81 @@ public void ThrowsIfNullUnpacker()
579579
Assert.Equal("unpacker", ex.ParamName);
580580
}
581581

582+
[Fact]
583+
public void HandlesUnion1()
584+
{
585+
var bytes = PackBytes(packer => packer.PackMapHeader(1)
586+
.Pack(nameof(ValueWrapper<Union<int, string>>.Value))
587+
.PackArrayHeader(2)
588+
.Pack("Int32")
589+
.Pack(123));
590+
591+
var after = new Deserialiser<ValueWrapper<Union<int, string>>>().Deserialise(bytes);
592+
593+
Assert.Equal(typeof(int), after.Value.Type);
594+
Assert.Equal(123, after.Value.Value);
595+
}
596+
597+
[Fact]
598+
public void HandlesUnion2()
599+
{
600+
var bytes = PackBytes(packer => packer.PackMapHeader(1)
601+
.Pack(nameof(ValueWrapper<Union<int, string>>.Value))
602+
.PackArrayHeader(2)
603+
.Pack("String")
604+
.Pack("Hello"));
605+
606+
var after = new Deserialiser<ValueWrapper<Union<int, string>>>().Deserialise(bytes);
607+
608+
Assert.Equal("Hello", after.Value);
609+
Assert.Equal(typeof(string), after.Value.Type);
610+
Assert.Equal("Hello", after.Value.Value);
611+
}
612+
613+
[Fact]
614+
public void ThrowsIfUnionHasWrongNumberOfArrayElements()
615+
{
616+
var bytes = PackBytes(packer => packer.PackMapHeader(1)
617+
.Pack(nameof(ValueWrapper<Union<int, string>>.Value))
618+
.PackArrayHeader(3)
619+
.Pack("String")
620+
.Pack("Hello")
621+
.Pack("World"));
622+
623+
var ex = Assert.Throws<DeserialisationException>(() => new Deserialiser<ValueWrapper<Union<int, string>>>().Deserialise(bytes));
624+
625+
Assert.Equal(@"Union array should have 2 elements (not 3) for property ""value"" of type ""Dasher.Union`2[System.Int32,System.String]""",
626+
ex.Message);
627+
}
628+
629+
[Fact]
630+
public void ThrowsIfReceivedTypeNotInUnion()
631+
{
632+
var bytes = PackBytes(packer => packer.PackMapHeader(1)
633+
.Pack(nameof(ValueWrapper<Union<int, double>>.Value))
634+
.PackArrayHeader(3)
635+
.Pack("String")
636+
.Pack("Hello"));
637+
638+
var ex = Assert.Throws<DeserialisationException>(() => new Deserialiser<ValueWrapper<Union<int, string>>>().Deserialise(bytes));
639+
640+
Assert.Equal(@"Union array should have 2 elements (not 3) for property ""value"" of type ""Dasher.Union`2[System.Int32,System.String]""",
641+
ex.Message);
642+
}
643+
644+
[Fact]
645+
public void ThrowsIfReceivedDataNotAnArray()
646+
{
647+
var bytes = PackBytes(packer => packer.PackMapHeader(1)
648+
.Pack(nameof(ValueWrapper<Union<int, double>>.Value))
649+
.Pack("String"));
650+
651+
var ex = Assert.Throws<DeserialisationException>(() => new Deserialiser<ValueWrapper<Union<int, string>>>().Deserialise(bytes));
652+
653+
Assert.Equal(@"Union values must be encoded as an array for property ""value"" of type ""Dasher.Union`2[System.Int32,System.String]""",
654+
ex.Message);
655+
}
656+
582657
#region Helper
583658

584659
private static byte[] PackBytes(Action<MsgPack.Packer> packAction)

Dasher.Tests/TypeSupportTests.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,14 @@ public void SupportsTuple3()
209209
Test(Tuple.Create(1, "Hello", true), packer => packer.PackArrayHeader(3).Pack(1).Pack("Hello").Pack(true));
210210
}
211211

212+
[Fact]
213+
public void SupportsUnion()
214+
{
215+
Test(Union<int, double>.Create(123), packer => packer.PackArrayHeader(2).Pack("Int32").Pack(123));
216+
Test(Union<int, double>.Create(123.0), packer => packer.PackArrayHeader(2).Pack("Double").Pack(123.0));
217+
Test(Union<int, string>.Create(null), packer => packer.PackArrayHeader(2).Pack("String").PackNull());
218+
}
219+
212220
#region Helper
213221

214222
private static T Test<T>(T value, Action<MsgPack.Packer> packAction)

Dasher.Tests/UnionProviderTests.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using Dasher.TypeProviders;
4+
using Xunit;
5+
6+
namespace Dasher.Tests
7+
{
8+
public class UnionProviderTests
9+
{
10+
[Fact]
11+
public void TypeNames()
12+
{
13+
Assert.Equal("String", UnionProvider.GetTypeName(typeof(string)));
14+
Assert.Equal("Int32", UnionProvider.GetTypeName(typeof(int)));
15+
Assert.Equal("Version", UnionProvider.GetTypeName(typeof(Version)));
16+
Assert.Equal("Guid", UnionProvider.GetTypeName(typeof(Guid)));
17+
18+
Assert.Equal("Union<Int32,String>", UnionProvider.GetTypeName(typeof(Union<int, string>))); ;
19+
20+
Assert.Equal("Dasher.Tests.ValueWrapper<String>", UnionProvider.GetTypeName(typeof(ValueWrapper<string>)));
21+
22+
Assert.Equal("Dasher.Tests.UnionProviderTests", UnionProvider.GetTypeName(typeof(UnionProviderTests)));
23+
24+
Assert.Equal("[String]", UnionProvider.GetTypeName(typeof(IReadOnlyList<string>)));
25+
26+
Assert.Equal("(Int32=>Boolean)", UnionProvider.GetTypeName(typeof(IReadOnlyDictionary<int, bool>)));
27+
28+
Assert.Equal("(Int32=>[Boolean])", UnionProvider.GetTypeName(typeof(IReadOnlyDictionary<int, IReadOnlyList<bool>>)));
29+
Assert.Equal("[(Int32=>Boolean)]", UnionProvider.GetTypeName(typeof(IReadOnlyList<IReadOnlyDictionary<int, bool>>)));
30+
}
31+
}
32+
}

Dasher/Dasher.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
<Compile Include="TypeProviders\ReadOnlyListProvider.cs" />
6161
<Compile Include="TypeProviders\TimeSpanProvider.cs" />
6262
<Compile Include="TypeProviders\TupleProvider.cs" />
63+
<Compile Include="TypeProviders\UnionProvider.cs" />
6364
<Compile Include="TypeProviders\VersionProvider.cs" />
6465
<Compile Include="UnexpectedFieldBehaviour.cs" />
6566
<Compile Include="MsgPack\Unpacker.cs" />

Dasher/DasherContext.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ public DasherContext(IEnumerable<ITypeProvider> typeProviders = null)
5656
new ReadOnlyListProvider(),
5757
new ReadOnlyDictionaryProvider(),
5858
new NullableValueProvider(),
59-
new TupleProvider()
59+
new TupleProvider(),
60+
new UnionProvider()
6061
};
6162

6263
if (typeProviders == null)
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Reflection;
5+
using System.Reflection.Emit;
6+
7+
namespace Dasher.TypeProviders
8+
{
9+
public sealed class UnionProvider : ITypeProvider
10+
{
11+
// Union types are serialised as an array of two values:
12+
// The string name of the type, including namespace and any generic type parameters
13+
// The serialised value, as per regular Dasher serialisation
14+
15+
public bool CanProvide(Type type)
16+
=> type.IsGenericType &&
17+
type.GetGenericTypeDefinition().Namespace == nameof(Dasher) &&
18+
type.GetGenericTypeDefinition().Name.StartsWith($"{nameof(Union<int, int>)}`");
19+
20+
public void Serialise(ILGenerator ilg, LocalBuilder value, LocalBuilder packer, LocalBuilder contextLocal, DasherContext context)
21+
{
22+
// write header
23+
ilg.Emit(OpCodes.Ldloc, packer);
24+
ilg.Emit(OpCodes.Ldc_I4_2);
25+
ilg.Emit(OpCodes.Call, typeof(UnsafePacker).GetMethod(nameof(UnsafePacker.PackArrayHeader)));
26+
27+
// TODO might be faster if we a generated class having members for use with called 'Union<>.Match'
28+
29+
var typeObj = ilg.DeclareLocal(typeof(Type));
30+
ilg.Emit(OpCodes.Ldloc, value);
31+
ilg.Emit(OpCodes.Callvirt, value.LocalType.GetProperty(nameof(Union<int, int>.Type)).GetMethod);
32+
ilg.Emit(OpCodes.Stloc, typeObj);
33+
34+
// write type name
35+
ilg.Emit(OpCodes.Ldloc, packer);
36+
ilg.Emit(OpCodes.Ldloc, typeObj);
37+
ilg.Emit(OpCodes.Call, typeof(UnionProvider).GetMethod(nameof(GetTypeName), BindingFlags.Static | BindingFlags.Public));
38+
ilg.Emit(OpCodes.Call, typeof(UnsafePacker).GetMethod(nameof(UnsafePacker.Pack), new[] { typeof(string) }));
39+
40+
// loop through types within the union, looking for a match
41+
var doneLabel = ilg.DefineLabel();
42+
var labelNextType = ilg.DefineLabel();
43+
foreach (var type in value.LocalType.GetGenericArguments())
44+
{
45+
ilg.LoadType(type);
46+
ilg.Emit(OpCodes.Ldloc, typeObj);
47+
ilg.Emit(OpCodes.Call, typeof(object).GetMethod(nameof(object.Equals), BindingFlags.Static | BindingFlags.Public));
48+
49+
// continue if this type doesn't match the union's values
50+
ilg.Emit(OpCodes.Brfalse, labelNextType);
51+
52+
// we have a match
53+
54+
// get the value
55+
var valueObj = ilg.DeclareLocal(type);
56+
ilg.Emit(OpCodes.Ldloc, value);
57+
ilg.Emit(OpCodes.Callvirt, value.LocalType.GetProperty(nameof(Union<int, int>.Value)).GetMethod);
58+
ilg.Emit(type.IsValueType ? OpCodes.Unbox_Any : OpCodes.Castclass, type);
59+
ilg.Emit(OpCodes.Stloc, valueObj);
60+
61+
// write value
62+
if (!context.TrySerialise(ilg, valueObj, packer, contextLocal))
63+
throw new Exception($"Unable to serialise type {type}");
64+
65+
ilg.Emit(OpCodes.Br, doneLabel);
66+
67+
ilg.MarkLabel(labelNextType);
68+
labelNextType = ilg.DefineLabel();
69+
}
70+
71+
ilg.MarkLabel(labelNextType);
72+
73+
ilg.Emit(OpCodes.Ldstr, "No match on union type");
74+
ilg.Emit(OpCodes.Newobj, typeof(Exception).GetConstructor(new[] {typeof(string)}));
75+
ilg.Emit(OpCodes.Throw);
76+
77+
ilg.MarkLabel(doneLabel);
78+
}
79+
80+
public void Deserialise(ILGenerator ilg, string name, Type targetType, LocalBuilder value, LocalBuilder unpacker, LocalBuilder contextLocal, DasherContext context, UnexpectedFieldBehaviour unexpectedFieldBehaviour)
81+
{
82+
// read the array length
83+
var count = ilg.DeclareLocal(typeof(int));
84+
ilg.Emit(OpCodes.Ldloc, unpacker);
85+
ilg.Emit(OpCodes.Ldloca, count);
86+
ilg.Emit(OpCodes.Call, typeof(Unpacker).GetMethod(nameof(Unpacker.TryReadArrayLength)));
87+
88+
var lbl0 = ilg.DefineLabel();
89+
ilg.Emit(OpCodes.Brtrue, lbl0);
90+
{
91+
ilg.Emit(OpCodes.Ldstr, "Union values must be encoded as an array for property \"{0}\" of type \"{1}\"");
92+
ilg.Emit(OpCodes.Ldstr, name);
93+
ilg.LoadType(value.LocalType);
94+
ilg.Emit(OpCodes.Call, typeof(string).GetMethod(nameof(string.Format), new[] { typeof(string), typeof(object), typeof(object) }));
95+
ilg.LoadType(targetType);
96+
ilg.Emit(OpCodes.Newobj, typeof(DeserialisationException).GetConstructor(new[] { typeof(string), typeof(Type) }));
97+
ilg.Emit(OpCodes.Throw);
98+
}
99+
ilg.MarkLabel(lbl0);
100+
101+
// ensure we have two items in the array
102+
var readValueLabel = ilg.DefineLabel();
103+
ilg.Emit(OpCodes.Ldloc, count);
104+
ilg.Emit(OpCodes.Ldc_I4_2);
105+
ilg.Emit(OpCodes.Beq, readValueLabel);
106+
{
107+
// throw due to incorrect number of items in Union array
108+
ilg.Emit(OpCodes.Ldstr, "Union array should have 2 elements (not {0}) for property \"{1}\" of type \"{2}\"");
109+
ilg.Emit(OpCodes.Ldloc, count);
110+
ilg.Emit(OpCodes.Box, typeof(int));
111+
ilg.Emit(OpCodes.Ldstr, name);
112+
ilg.LoadType(value.LocalType);
113+
ilg.Emit(OpCodes.Call, typeof(string).GetMethod(nameof(string.Format), new[] {typeof(string), typeof(object), typeof(object), typeof(object)}));
114+
ilg.LoadType(targetType);
115+
ilg.Emit(OpCodes.Newobj, typeof(DeserialisationException).GetConstructor(new[] {typeof(string), typeof(Type)}));
116+
ilg.Emit(OpCodes.Throw);
117+
}
118+
ilg.MarkLabel(readValueLabel);
119+
120+
// read the serialised type name
121+
var typeName = ilg.DeclareLocal(typeof(string));
122+
ilg.Emit(OpCodes.Ldloc, unpacker);
123+
ilg.Emit(OpCodes.Ldloca, typeName);
124+
ilg.Emit(OpCodes.Call, typeof(Unpacker).GetMethod(nameof(Unpacker.TryReadString), new[] { typeof(string).MakeByRefType() }));
125+
126+
var lbl1 = ilg.DefineLabel();
127+
ilg.Emit(OpCodes.Brtrue, lbl1);
128+
{
129+
ilg.Emit(OpCodes.Ldstr, "Unable to read union type name for property \"{0}\" of type \"{1}\"");
130+
ilg.Emit(OpCodes.Ldstr, name);
131+
ilg.LoadType(value.LocalType);
132+
ilg.Emit(OpCodes.Call, typeof(string).GetMethod(nameof(string.Format), new[] { typeof(string), typeof(object), typeof(object) }));
133+
ilg.LoadType(targetType);
134+
ilg.Emit(OpCodes.Newobj, typeof(DeserialisationException).GetConstructor(new[] { typeof(string), typeof(Type) }));
135+
ilg.Emit(OpCodes.Throw);
136+
}
137+
ilg.MarkLabel(lbl1);
138+
139+
// loop through types within the union, looking for a matching type name
140+
var doneLabel = ilg.DefineLabel();
141+
var labelNextType = ilg.DefineLabel();
142+
foreach (var type in value.LocalType.GetGenericArguments())
143+
{
144+
var expectedTypeName = GetTypeName(type);
145+
146+
ilg.Emit(OpCodes.Ldloc, typeName);
147+
ilg.Emit(OpCodes.Ldstr, expectedTypeName);
148+
ilg.Emit(OpCodes.Call, typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Static | BindingFlags.Public, null, new[] {typeof(string),typeof(string)}, null));
149+
150+
// continue if this type doesn't match the union's values
151+
ilg.Emit(OpCodes.Brfalse, labelNextType);
152+
153+
// we have a match
154+
// read the value
155+
var readValue = ilg.DeclareLocal(type);
156+
if (!context.TryDeserialise(ilg, name, targetType, readValue, unpacker, contextLocal, unexpectedFieldBehaviour))
157+
throw new Exception($"Unable to deserialise values of type {type} from MsgPack data.");
158+
159+
// create the union
160+
ilg.Emit(OpCodes.Ldloc, readValue);
161+
ilg.Emit(OpCodes.Call, value.LocalType.GetMethod(nameof(Union<int, int>.Create), new[] {type}));
162+
163+
// store it in the result value
164+
ilg.Emit(OpCodes.Stloc, value);
165+
166+
// exit the loop
167+
ilg.Emit(OpCodes.Br, doneLabel);
168+
169+
ilg.MarkLabel(labelNextType);
170+
labelNextType = ilg.DefineLabel();
171+
}
172+
173+
ilg.MarkLabel(labelNextType);
174+
175+
// TODO include received type name in error message and some more general info
176+
ilg.Emit(OpCodes.Ldstr, "No match on union type");
177+
ilg.Emit(OpCodes.Newobj, typeof(Exception).GetConstructor(new[] { typeof(string) }));
178+
ilg.Emit(OpCodes.Throw);
179+
180+
ilg.MarkLabel(doneLabel);
181+
}
182+
183+
public static string GetTypeName(Type type)
184+
{
185+
if (!type.IsGenericType)
186+
return type.Namespace == nameof(System) ? type.Name : type.FullName;
187+
188+
var arguments = type.GetGenericArguments();
189+
if (arguments.Length == 1 && type.GetGenericTypeDefinition() == typeof(IReadOnlyList<>))
190+
return $"[{GetTypeName(arguments[0])}]";
191+
if (arguments.Length == 2 && type.GetGenericTypeDefinition() == typeof(IReadOnlyDictionary<,>))
192+
return $"({GetTypeName(arguments[0])}=>{GetTypeName(arguments[1])})";
193+
194+
var baseName = type.FullName.StartsWith("Dasher.Union`")
195+
? "Union"
196+
: type.FullName.Substring(0, type.FullName.IndexOf('`'));
197+
198+
return $"{baseName}<{string.Join(",", arguments.Select(GetTypeName))}>";
199+
}
200+
}
201+
}

0 commit comments

Comments
 (0)