1 module cerealed.cereal;
2 
3 import cerealed.traits: isCereal, isCerealiser, isDecerealiser;
4 import std.traits; // too many to bother listing
5 import std.range: isInputRange, isOutputRange, isInfinite;
6 
/**
 * Exception type this module throws for malformed or inconsistent input,
 * e.g. a declared length larger than the bytes remaining in the stream.
 */
class CerealException : Exception {
    /// Forwards everything to Exception; file/line default to the construction site.
    this(string msg, string file = __FILE__, size_t line = __LINE__, Throwable next = null) pure @safe {
        super(msg, file, line, next);
    }
}
12 
/// The two kinds of cereal: one writing bytes, one reading them.
enum CerealType {
    WriteBytes = 0,
    ReadBytes  = 1,
}
14 
/// (De)serialises a single byte via the cereal's byte primitive; no conversion needed.
void grain(C, T)(auto ref C cereal, ref T val) if(isCereal!C && is(T == ubyte)) {
    // ubyte is the unit grainUByte works in, so pass it straight through
    cereal.grainUByte(val);
}
18 
// Signed integers, bool, char and floating-point values (enums excluded) are
// grained by reinterpreting their bits as the same-sized unsigned type.
void grain(C, T)(auto ref C cereal, ref T val) @safe
    if(isCereal!C && !is(T == enum) &&
       (isSigned!T || isBoolean!T || is(T == char) || isFloatingPoint!T))
{
    cereal.grainReinterpret(val);
}
25 
/**
 * (De)serialises an enum through its unqualified original (base) type.
 *
 * The value is copied into a local of the base type, grained, and (when `val`
 * is mutable) written back. Decoded values therefore land in `val` without
 * relying on a cast expression being an lvalue.
 *
 * Throws: CerealException (an Exception) if, after graining, the value lies
 * outside [T.min, T.max].
 */
void grain(C, T)(auto ref C cereal, ref T val) @safe if(isCereal!C && is(T == enum)) {
    import std.conv: text;

    alias BaseType = Unqual!(OriginalType!(T));
    BaseType base = cast(BaseType)val;
    cereal.grain(base);
    // When cerealising this stores back the same value; when decerealising it
    // stores the decoded one.
    static if(isMutable!T)
        val = cast(T)base;

    if(val < T.min || val > T.max)
        throw new CerealException(text("Illegal value (", val, ") for type ", T.stringof));
}
35 
36 
/// A wchar is grained as the ushort sharing its bit pattern.
void grain(C, T)(auto ref C cereal, ref T val) @trusted if(isCereal!C && is(T == wchar)) {
    auto asUShort = cast(ushort*)&val; // same size, so an in-place reinterpretation
    cereal.grain(*asUShort);
}
40 
/// A dchar is grained as the uint sharing its bit pattern.
void grain(C, T)(auto ref C cereal, ref T val) @trusted if(isCereal!C && is(T == dchar)) {
    auto asUInt = cast(uint*)&val; // same size, so an in-place reinterpretation
    cereal.grain(*asUInt);
}
44 
/// (De)serialises a ushort as two bytes, most significant byte first.
void grain(C, T)(auto ref C cereal, ref T val) @safe if(isCereal!C && is(T == ushort)) {
    ubyte hi = cast(ubyte)(val >> 8);
    ubyte lo = cast(ubyte)val;
    cereal.grainUByte(hi);
    cereal.grainUByte(lo);
    // rebuild from the (possibly decerealised) bytes
    val = cast(T)((hi << 8) | lo);
}
52 
/// (De)serialises a uint as four bytes, most significant byte first.
void grain(C, T)(auto ref C cereal, ref T val) @safe if(isCereal!C && is(T == uint)) {
    ubyte[4] octets = [
        cast(ubyte)(val >> 24),
        cast(ubyte)(val >> 16),
        cast(ubyte)(val >> 8),
        cast(ubyte)val,
    ];
    foreach(ref octet; octets)
        cereal.grainUByte(octet);

    // rebuild from the (possibly decerealised) bytes
    val = (cast(uint)octets[0] << 24) | (cast(uint)octets[1] << 16) |
          (cast(uint)octets[2] << 8) | octets[3];
}
64 
/// (De)serialises a ulong as eight bytes, most significant byte first.
void grain(C, T)(auto ref C cereal, ref T val) @safe if(isCereal!C && is(T == ulong)) {
    T rebuilt = 0;
    // byteIndex 7 is the most significant byte, so counting down gives big-endian order
    foreach_reverse(byteIndex; 0 .. T.sizeof) {
        immutable shift = byteIndex * 8;
        ubyte octet = cast(ubyte)(val >> shift);
        cereal.grainUByte(octet);
        rebuilt |= cast(T)octet << shift;
    }
    val = rebuilt;
}
75 
// Intended to detect ranges/arrays whose elements are single bytes convertible
// to ubyte, which the (de)serialisers can then move as raw bytes.
// NOTE(review): `T.sizeof` is the size of the range/array type itself (a slice
// is two words), not of its element — `ElementType!T.sizeof` was presumably
// meant. Also, the module-level imports shown only bring isInputRange,
// isOutputRange and isInfinite in from std.range, so `ElementType` may not
// resolve here and the is(...) may always be false, making this trait false
// for every T. Confirm before relying on the raw-bytes paths guarded by it.
enum hasByteElement(T) = is(Unqual!(ElementType!T): ubyte) && T.sizeof == 1;
77 
/// Cerealises a finite input range with the default (ushort) length prefix.
void grain(C, T)(auto ref C cereal, ref T val) @trusted if(isCerealiser!C &&
                                                           isInputRange!T && !isInfinite!T &&
                                                           !is(T == string) &&
                                                           !isStaticArray!T &&
                                                           !isAssociativeArray!T) {
    alias DefaultLength = ushort;
    grain!DefaultLength(cereal, val);
}
85 
/**
 * Cerealises a finite input range as a `U` element count followed by each
 * element. The range must expose `.length` (checked at compile time). A length
 * that does not fit in `U` trips an assertion.
 */
void grain(U, C, T)(auto ref C cereal, ref T val) @trusted if(isCerealiser!C &&
                                                              isInputRange!T && !isInfinite!T &&
                                                              !is(T == string) &&
                                                              !isStaticArray!T &&
                                                              !isAssociativeArray!T) {
    import std.array: array;
    import std.conv: text;
    import std.range: hasSlicing;

    static assert(is(typeof(() { auto l = val.length; })),
                  text("Only InputRanges with .length accepted, not the case for ",
                       fullyQualifiedName!T));

    auto prefix = cast(U)val.length;
    assert(prefix == val.length,
           text(C.stringof, " overflow. Length: ", prefix, ". Val length: ", val.length, "\n",
                val.array));
    cereal.grain(prefix);

    static if(!(hasSlicing!(Unqual!T) && hasByteElement!T)) {
        foreach(ref element; val) cereal.grain(element);
    } else {
        cereal.grainRaw(cast(ubyte[])val.array);
    }
}
109 
110 
/// (De)serialises a static array element by element (no length prefix),
/// or as raw bytes when hasByteElement!T holds.
void grain(C, T)(auto ref C cereal, ref T val) @safe if(isCereal!C && isStaticArray!T) {
    static if(!hasByteElement!T) {
        foreach(ref element; val) cereal.grain(element);
    } else {
        cereal.grainRaw(cast(ubyte[])val);
    }
}
117 
/// Decerealises into a byte output range using the default (ushort) length prefix.
void grain(C, T)(auto ref C cereal, ref T val) @trusted if(isDecerealiser!C &&
                                                           !isStaticArray!T &&
                                                           isOutputRange!(T, ubyte)) {
    alias DefaultLength = ushort;
    grain!DefaultLength(cereal, val);
}
123 
/**
 * Decerealises into an output range of bytes. Reads a `U` element count, then
 * either delegates to decerealiseArrayImpl (for arrays) or reads that many
 * bytes, appending each with `~=` when supported and `put` otherwise.
 */
void grain(U, C, T)(auto ref C cereal, ref T val) @trusted if(isDecerealiser!C &&
                                                              !isStaticArray!T &&
                                                              isOutputRange!(T, ubyte)) {
    version(DigitalMars)
        U count;
    else
        U count = void;

    cereal.grain(count);

    static if(!isArray!T) {
        // prefer `~=` when the range supports it, falling back to put
        enum appendable = is(typeof(() { ubyte b; val ~= b; }));

        foreach(_; 0 .. count) {
            ubyte octet = void;
            cereal.grain(octet);
            static if(appendable)
                val ~= octet;
            else
                val.put(octet);
        }
    } else {
        decerealiseArrayImpl(cereal, val, count);
    }
}
150 
/**
 * Decerealises `length` elements into the dynamic array `val`.
 *
 * Before resizing `val`, checks that the decerealiser has at least as many bytes
 * left as the elements must minimally occupy. A corrupt length prefix is then
 * rejected instead of causing a huge allocation.
 *
 * Params:
 *     cereal = the decerealiser to read from
 *     val    = the array to fill; resized to `length` elements
 *     length = the element count, already read from the stream
 *
 * Throws: Exception (via enforce) if too few bytes are left.
 */
private void decerealiseArrayImpl(C, T, U)(auto ref C cereal, ref T val, U length) @safe
    if(is(T == E[], E) && isDecerealiser!C)
{
    import std.exception: enforce;
    import std.conv: text;
    import std.traits: isDynamicArray, isPointer, isScalarType;

    // Lower bound on the encoded size of one element. It must never exceed the
    // real encoded size, or valid input would be rejected:
    // - the declared element type is used, not std.range.ElementType, which
    //   autodecodes char[]/wchar[] to 4-byte dchar;
    // - nested arrays are read through the default overloads, which only
    //   guarantee a ushort length prefix (the inner array may be empty);
    // - anything else (aggregates, static arrays, pointers, ...) gets no bound.
    ulong minElementBytes(Elem)() {
        static if(isPointer!Elem)
            return 0;
        else static if(isScalarType!Elem)
            return Elem.sizeof;
        else static if(isDynamicArray!Elem)
            return ushort.sizeof;
        else
            return 0;
    }

    immutable needed = length * minElementBytes!(typeof(T.init[0]))();
    enforce(needed <= cereal.bytesLeft,
            text("Not enough bytes left to decerealise ", T.stringof, " of ", length, " elements\n",
                 "Bytes left: ", cereal.bytesLeft, ", Needed: ", needed, ", bytes: ", cereal.bytes));

    static if(hasByteElement!T) {
        val = cereal.grainRaw(length).dup;
    } else {
        // size_t rather than uint so wide length types are not truncated
        if(val.length != length) val.length = cast(size_t)length;
        assert(length == val.length, "overflow");

        foreach(ref e; val) cereal.grain(e);
    }
}
184 
/// Decerealises a dynamic array (not string) using the default (ushort) length prefix.
void grain(C, T)(auto ref C cereal, ref T val) @trusted if(isDecerealiser!C &&
                                                           !isOutputRange!(T, ubyte) &&
                                                           isDynamicArray!T && !is(T == string)) {
    alias DefaultLength = ushort;
    grain!DefaultLength(cereal, val);
}
190 
/// Decerealises a dynamic array (not string) preceded by a `U` element count.
void grain(U, C, T)(auto ref C cereal, ref T val) @trusted if(isDecerealiser!C &&
                                                              !isOutputRange!(T, ubyte) &&
                                                              isDynamicArray!T && !is(T == string)) {
    version(DigitalMars)
        U count;
    else
        U count = void;

    // the element count comes first on the wire
    cereal.grain(count);
    decerealiseArrayImpl(cereal, val, count);
}
202 
/// (De)serialises a string using the default (ushort) length prefix.
void grain(C, T)(auto ref C cereal, ref T val) @trusted if(isCereal!C && is(T == string)) {
    alias DefaultLength = ushort;
    grain!DefaultLength(cereal, val);
}
206 
/**
 * (De)serialises a string as a `U` byte count followed by its raw UTF-8 bytes.
 *
 * When cerealising, a length that does not fit in `U` trips an assertion. When
 * decerealising, the bytes are copied, so the resulting immutable string never
 * aliases the decerealiser's buffer.
 */
void grain(U, C, T)(auto ref C cereal, ref T val) @trusted if(isCereal!C && is(T == string)) {
    static if(isCerealiser!C) {
        U length = cast(U)val.length;
        assert(length == val.length, "overflow");
        cereal.grain(length);
        cereal.grainRaw(cast(ubyte[])val);
    } else {
        // the previous contents of `val` are irrelevant here: the length comes from the stream
        U length;
        cereal.grain(length);
        // grainRaw hands back bytes owned by the decerealiser; casting them to
        // immutable without a copy would break string's immutability guarantee
        val = cast(string)cereal.grainRaw(length).idup;
    }
}
217 
/// (De)serialises an associative array using the default (ushort) entry-count prefix.
void grain(C, T)(auto ref C cereal, ref T val) @trusted if(isCereal!C && isAssociativeArray!T) {
    alias DefaultLength = ushort;
    grain!DefaultLength(cereal, val);
}
221 
/**
 * (De)serialises an associative array as a `U` entry count followed by that many
 * key/value pairs.
 *
 * Cerealising writes every entry of `val` (an entry count that does not fit in
 * `U` trips an assertion). Decerealising reads the count and that many pairs from
 * the stream and inserts each into `val`; pre-existing entries under other keys
 * are left untouched.
 */
void grain(U, C, T)(auto ref C cereal, ref T val) @trusted if(isCereal!C && isAssociativeArray!T) {
    static if(isCerealiser!C) {
        U length = cast(U)val.length;
        assert(length == val.length, "overflow");
        cereal.grain(length);

        const keys = val.keys;
        for(U i = 0; i < length; ++i) {
            KeyType!T k = keys[i];
            auto v = val[k];
            cereal.grain(k);
            cereal.grain(v);
            val[k] = v;
        }
    } else {
        // The count comes from the stream; whatever `val` already holds must not
        // drive the loop (indexing its keys could run past their end).
        U length;
        cereal.grain(length);

        for(U i = 0; i < length; ++i) {
            KeyType!T k = KeyType!T.init;
            auto v = ValueType!T.init;
            cereal.grain(k);
            cereal.grain(v);
            val[k] = v;
        }
    }
}
237 
/// (De)serialises the value a pointer points to (not the address). When
/// decerealising through a null pointer, a fresh target is allocated first.
void grain(C, T)(auto ref C cereal, ref T val) @safe if(isCereal!C && isPointer!T) {
    alias Pointee = PointerTarget!T;

    static if(isDecerealiser!C) {
        if(val is null)
            val = new Pointee;
    }

    cereal.grain(*val);
}
246 
/**
 * Compile-time check: true if `val.<func>(cer)` compiles, where `cer` is a
 * default-constructed `C` and `val` is `T.init`.
 * If T has a member named `func` but that call does not compile, a warning is
 * printed at compile time via pragma(msg), so a broken hook is reported rather
 * than silently ignored.
 */
private template canCall(C, T, string func) {
    enum canCall = is(typeof(() { auto cer = C(); auto val = T.init; mixin("val." ~ func ~ "(cer);"); }));
    static if(!canCall && __traits(hasMember, T, func)) {
        pragma(msg, "Warning: '" ~ func ~
               "' function defined for ", T, ", but does not compile for Cereal ", C);
    }
}
254 
/**
 * (De)serialises a struct or class that is not itself a byte range.
 *
 * If `val.accept(cereal)` compiles, it replaces the member-wise path entirely;
 * preBlit/postBlit hooks are then forbidden. Otherwise `preBlit` (if callable)
 * runs first, then every data member is grained, then `postBlit` (if callable).
 */
void grain(C, T)(auto ref C cereal, ref T val) @trusted if(isCereal!C && isAggregateType!T &&
                                                           !isInputRange!T && !isOutputRange!(T, ubyte)) {
    enum hasAccept   = canCall!(C, T, "accept");
    enum hasPreBlit  = canCall!(C, T, "preBlit");
    enum hasPostBlit = canCall!(C, T, "postBlit");

    static if(!hasAccept) {
        static if(hasPreBlit) val.preBlit(cereal);
        cereal.grainAllMembers(val);
        static if(hasPostBlit) val.postBlit(cereal);
    } else {
        static assert(!(hasPreBlit || hasPostBlit), "Cannot define both accept and pre/postBlit");
        val.accept(cereal);
    }
}
275 
/// Grains the data members declared by struct `T`.
void grainAllMembers(C, T)(auto ref C cereal, ref T val) @safe if(isCereal!C && is(T == struct)) {
    alias Declaring = T;
    cereal.grainAllMembersImpl!Declaring(val);
}
279 
280 
/**
 * Grains a class instance via `cereal.grainClass`.
 *
 * A null instance is not allowed when cerealising. When decerealising a class
 * with a usable default constructor, a null `val` is replaced by `new T` first;
 * otherwise `val` must already be non-null.
 */
void grainAllMembers(C, T)(auto ref C cereal, ref T val) @trusted if(isCereal!C && is(T == class)) {
    import std.conv: text;

    static if(isCerealiser!C)
        assert(val !is null, "null value cannot be serialised");

    enum canDefaultConstruct = is(typeof(() { val = new T; }));

    static if(isDecerealiser!C && canDefaultConstruct) {
        if(val is null)
            val = new T;
    } else {
        assert(val !is null, text("Cannot deserialise into null value. ",
                                  "Possible cause: no default constructor for ",
                                  fullyQualifiedName!T, "."));
    }

    cereal.grainClass(val);
}
300 
301 
alias grainMemberWithAttr = grainAggregateMember;

/**
 * (De)serialises the member named `member` of `val`, honouring its attributes.
 * Members annotated with @NoCereal are skipped entirely; everything else is
 * handed to grainMember.
 */
void grainAggregateMember(string member, C, T)(auto ref C cereal, ref T val) @trusted if(isCereal!C) {
    import cerealed.attrs: NoCereal;
    import std.meta: staticIndexOf;

    enum skip = staticIndexOf!(NoCereal,
                               __traits(getAttributes, __traits(getMember, val, member))) != -1;

    static if(!skip)
        grainMember!member(cereal, val);
}
316 
/**
 * (De)serialises member `member` of `val`, choosing the strategy from the
 * member's user-defined attributes (from cerealed.attrs).
 *
 * At most one each of Bits!N, ArrayLength, LengthInBytes and LengthType is
 * allowed (enforced by static asserts). When attributes of different kinds are
 * present, precedence is: Bits, LengthType, RawArray, ArrayLength,
 * LengthInBytes. With none of them, the member is grained normally.
 */
void grainMember(string member, C, T)(auto ref C cereal, ref T val) @trusted if(isCereal!C) {

    import cerealed.attrs:
        isABitsStruct, isArrayLengthStruct, isLengthInBytesStruct, RawArray, isLengthType;
    import std.meta: staticIndexOf, Filter;

    alias bitsAttrs = Filter!(isABitsStruct, __traits(getAttributes,
                                                      __traits(getMember, val, member)));
    static assert(bitsAttrs.length == 0 || bitsAttrs.length == 1,
                  "Too many Bits!N attributes!");

    alias arrayLengths = Filter!(isArrayLengthStruct,
                                 __traits(getAttributes,
                                          __traits(getMember, val, member)));
    static assert(arrayLengths.length == 0 || arrayLengths.length == 1,
                  "Too many ArrayLength attributes");

    alias lengthInBytes = Filter!(isLengthInBytesStruct,
                                  __traits(getAttributes,
                                           __traits(getMember, val, member)));
    static assert(lengthInBytes.length == 0 || lengthInBytes.length == 1,
                  "Too many LengthInBytes attributes");

    enum rawArrayIndex = staticIndexOf!(RawArray, __traits(getAttributes,
                                                           __traits(getMember, val, member)));

    alias lengthTypes = Filter!(isLengthType, __traits(getAttributes, __traits(getMember, val, member)));
    static assert(lengthTypes.length == 0 || lengthTypes.length == 1,
                  "Too many LengthType attributes");

    static if(bitsAttrs.length == 1) {

        // @Bits!N: grained using the number of bits the attribute specifies
        grainWithBitsAttr!(member, bitsAttrs[0])(cereal, val);

    } else static if(lengthTypes.length == 1) {

        // @LengthType: use the attribute's Type for the length prefix instead of the default
        grain!(lengthTypes[0].Type)(cereal, __traits(getMember, val, member));

    } else static if(rawArrayIndex != -1) {

        // @RawArray: no length prefix; decerealising consumes the rest of the stream
        cereal.grainRawArray(__traits(getMember, val, member));

    } else static if(arrayLengths.length > 0) {

        // @ArrayLength: element count comes from another expression over `val`
        grainWithArrayLengthAttr!(member, arrayLengths[0].member)(cereal, val);

    } else static if(lengthInBytes.length > 0) {

        // @LengthInBytes: byte count comes from another expression over `val`
        grainWithLengthInBytesAttr!(member, lengthInBytes[0].member)(cereal, val);

    } else {

        cereal.grain(__traits(getMember, val, member));

    }
}
373 
/**
 * (De)serialises member `member` of `val` using the bit count given by the
 * @Bits!N attribute `bitsAttr`. It is a compile-time error if the member's
 * type has fewer than N bits.
 */
private void grainWithBitsAttr(string member, alias bitsAttr, C, T)(
    auto ref C cereal, ref T val) @safe if(isCereal!C) {

    import cerealed.attrs: getNumBits;
    import std.conv: text;

    enum bitsRequested = getNumBits!(bitsAttr);
    enum bitsAvailable = 8 * __traits(getMember, val, member).sizeof;

    static assert(bitsRequested <= bitsAvailable,
                  text(fullyQualifiedName!T, ".", member, " is ", bitsAvailable,
                       " bits long, which is not enough to store @Bits!", bitsRequested));

    cereal.grainBitsT(__traits(getMember, val, member), bitsRequested);
}
387 
/**
 * (De)serialises slice member `member` whose element count is given by the
 * expression `lengthMember`, evaluated with `val`'s members in scope (@ArrayLength).
 *
 * Cerealising writes the elements back to back with no length prefix.
 * Decerealising evaluates the length, checks it against the bytes remaining,
 * resizes the member and grains each element.
 *
 * Throws: CerealException if length * E.sizeof exceeds the remaining bytes.
 */
private void grainWithArrayLengthAttr(string member, string lengthMember, C, T)
    (auto ref C cereal, ref T val) @safe if(isCereal!C) {

    import std.conv: text;
    import std.range: ElementType;

    checkArrayAttrType!member(cereal, val);

    static if(isCerealiser!C) {
        cereal.grainRawArray(__traits(getMember, val, member));
    } else {
        immutable length = lengthOfArray!(member, lengthMember)(cereal, val);
        alias E = ElementType!(typeof(__traits(getMember, val, member)));

        // NOTE(review): E.sizeof is the in-memory size, which can exceed the
        // encoded size (struct padding, @NoCereal fields, or ElementType
        // autodecoding char[] to 4-byte dchar). Valid input near the end of the
        // stream may be rejected here — confirm.
        if(length * E.sizeof  > cereal.bytesLeft) {
            throw new CerealException(text("@ArrayLength of ", length, " units of type ",
                                           E.stringof,
                                           " (", length * E.sizeof, " bytes) ",
                                           "larger than remaining byte array (",
                                           cereal.bytesLeft, " bytes)\n",
                                          cereal.bytes));
        }

        mixin(q{__traits(getMember, val, member).length = length;});

        foreach(ref e; __traits(getMember, val, member)) cereal.grain(e);
    }
}
416 
/**
 * (De)serialises slice member `member` whose encoded size in BYTES (not
 * elements) is given by the expression `lengthMember`, evaluated with `val`'s
 * members in scope (@LengthInBytes).
 *
 * Cerealising writes the elements back to back with no length prefix.
 * Decerealising appends elements one at a time until the declared number of
 * bytes has been consumed.
 *
 * Throws: CerealException if the declared byte length exceeds the bytes
 * remaining, or if an element consumes no bytes. The latter would otherwise
 * loop forever.
 */
void grainWithLengthInBytesAttr(string member, string lengthMember, C, T)
                                (auto ref C cereal, ref T val) @safe if(isCereal!C) {

    import std.conv: text;

    checkArrayAttrType!member(cereal, val);

    static if(isCerealiser!C) {
        cereal.grainRawArray(__traits(getMember, val, member));
    } else {
        immutable length = lengthOfArray!(member, lengthMember)(cereal, val); //error handling

        if(length > cereal.bytesLeft) {
            throw new CerealException(text("@LengthInBytes of ", length, " bytes ",
                                           "larger than remaining byte array (",
                                           cereal.bytesLeft, " bytes)"));
        }

        __traits(getMember, val, member).length = 0;

        long bytesLeft = length;
        // `> 0`, not `!= 0`: an element straddling the declared boundary drives
        // the counter negative, which must still end the loop.
        while(bytesLeft > 0) {
            auto origCerealBytesLeft = cereal.bytesLeft;
            __traits(getMember, val, member).length++;
            cereal.grain(__traits(getMember, val, member)[$ - 1]);
            immutable consumed = origCerealBytesLeft - cereal.bytesLeft;
            if(consumed == 0)
                throw new CerealException(text("@LengthInBytes: element of ", member,
                                               " consumed no bytes, cannot make progress"));
            bytesLeft -= consumed;
        }
    }
}
448 
/// Compile-time check that member `member` of `val` is a slice (E[]); the
/// @ArrayLength and @LengthInBytes attributes are only supported on slices.
private void checkArrayAttrType(string member, C, T)(auto ref C cereal, ref T val) @safe if(isCereal!C) {
    import std.conv: text;

    alias MemberType = typeof(__traits(getMember, val, member));
    enum isSlice = is(MemberType == E[], E);

    static assert(isSlice,
                  text("@ArrayLength and @LengthInBytes not valid for ", member,
                       ": they can only be used on slices"));
}
458 
459 
/**
 * Evaluates the D expression `lengthMember` with `val`'s members in scope
 * (via `with(val)`) and returns the result as an int length.
 * Called for both @ArrayLength and @LengthInBytes members.
 *
 * Throws: CerealException if the result is negative. The message names
 * @LengthInBytes regardless of which attribute triggered the call.
 */
private int lengthOfArray(string member, string lengthMember, C, T)(auto ref C cereal, ref T val)
    @safe if(isCereal!C) {

    import std.conv: text;

    // Presumably given an unusual name to avoid clashing with a member of `val`,
    // since the mixin below is resolved inside `with(val)`.
    int _tmpLen;
    // NOTE(review): cast(int) may silently truncate lengths that do not fit in an int.
    mixin(q{with(val) _tmpLen = cast(int)(} ~ lengthMember ~ q{);});

    if(_tmpLen < 0)
        throw new CerealException(text("@LengthInBytes resulted in negative length ", _tmpLen));

    return _tmpLen;
}
473 
/**
 * (De)serialises `val` with no length prefix. Cerealising writes each element.
 * Decerealising discards the current contents and keeps appending elements
 * until the decerealiser has no bytes left.
 */
void grainRawArray(C, T)(auto ref C cereal, ref T[] val) @trusted if(isCereal!C) {
    //can't use virtual functions due to template parameter
    static if(!isDecerealiser!C) {
        foreach(ref element; val) cereal.grain(element);
    } else {
        val.length = 0;
        while(cereal.bytesLeft()) {
            val.length += 1;
            cereal.grain(val[$ - 1]);
        }
    }
}
486 
487 
/**
 * To be used when the length of the array is known at run-time based on the value
 * of a part of byte stream.
 *
 * `val` is resized to exactly `length` elements (growing or shrinking it), then
 * each element is grained in order.
 *
 * Throws: CerealException if `length` is negative.
 */
void grainLengthedArray(C, T)(auto ref C cereal, ref T[] val, long length) {
    import std.conv: text;

    // A negative length usually means a corrupt stream. Converted to size_t, it
    // would request an enormous allocation instead of failing cleanly.
    if(length < 0)
        throw new CerealException(text("Negative array length ", length));

    val.length = cast(typeof(val.length))length;
    foreach(ref t; val) cereal.grain(t);
}
496 
497 
/// Grains a class instance: members inherited from base classes first, then
/// those declared by `T` itself.
package void grainClassImpl(C, T)(auto ref C cereal, ref T val) @safe if(isCereal!C && is(T == class)) {
    // wire order must start with the bases, so they go before T's own members
    cereal.grainBaseClasses(val);

    alias Declaring = T;
    cereal.grainAllMembersImpl!Declaring(val);
}
503 
/// Grains `val` using `bits` bits, going through a uint temporary because
/// grainBits operates on uint.
private void grainBitsT(C, T)(auto ref C cereal, ref T val, int bits) @safe if(isCereal!C) {
    uint widened = val;
    cereal.grainBits(widened, bits);
    val = cast(T)widened;
}
509 
/// Grains `val` in place through a pointer to its same-sized unsigned
/// counterpart, as chosen by CerealPtrType.
private void grainReinterpret(C, T)(auto ref C cereal, ref T val) @trusted if(isCereal!C) {
    cereal.grain(*cast(CerealPtrType!T)&val);
}
514 
/**
 * Grains the members declared in every ancestor class of `T`, ordered from the
 * root of the hierarchy down to T's immediate parent, so the most basic class
 * goes first on the wire. T's own members are not handled here.
 */
private void grainBaseClasses(C, T)(auto ref C cereal, ref T val) @safe if(isCereal!C && is(T == class)) {
    import std.meta: Reverse;

    // BaseClassesTuple lists *all* ancestors (immediate parent first, Object
    // last). BaseTypeTuple only yields the immediate parent (plus interfaces),
    // which silently dropped the members of grandparent classes.
    foreach(base; Reverse!(BaseClassesTuple!T)) {
        cereal.grainAllMembersImpl!base(val);
    }
}
520 
521 
/**
 * Grains, via grainAggregateMember, each member declared directly in
 * `ActualType` that can be assigned its own `.init` through `val`. This skips
 * methods, nested types and other non-data members.
 */
private void grainAllMembersImpl(ActualType, C, ValType)
                                (auto ref C cereal, ref ValType val) @trusted if(isCereal!C) {
    foreach(name; __traits(derivedMembers, ActualType)) {
        static if(is(typeof(() {
                      __traits(getMember, val, name) = __traits(getMember, val, name).init;
                  }))) {
            cereal.grainAggregateMember!name(val);
        }
    }
}
534 
/// Pointer type used to reinterpret a value as the unsigned integer of the
/// same size: bool/char -> ubyte, float -> uint, double -> ulong, and any
/// other type -> Unsigned!T.
private template CerealPtrType(T) {
    static if(is(T == double))
        alias CerealPtrType = ulong*;
    else static if(is(T == float))
        alias CerealPtrType = uint*;
    else static if(is(T == bool) || is(T == char))
        alias CerealPtrType = ubyte*;
    else
        alias CerealPtrType = Unsigned!T*;
}