Monday 20 July 2015

Calculating PI in C# using Monte-Carlo simulation

The following code sample shows numeric computation of the number PI using Monte-Carlo simulation. First, a sequential approach is used. Then the Parallel.For construct in TPL is used. Finally, we use Tasks in TPL.


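To compute PI, all three versions use the same approach. We consider the unit circle inscribed in a square centered on the origin, with corners at coordinates (-1,-1), (-1, 1), (1,1) and (1,-1). PI can be approximated by generating random points inside the square and looking at the ratio of the number of points that land inside the circle, M, to the total number of points generated, N. The code below draws x and y from the range [0, 1), so it only samples the upper-right quadrant; by symmetry the ratio is the same (a quarter circle of area PI / 4 inside a square of area 1). For uniformly distributed points, the following can then be expected: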

(a) M / N = PI / 4

Why? Because the ratio of hits approaches the ratio of the areas. The square has sides of length 2, so its area is 2 * 2 = 4. The unit circle has a radius of 1, hence its area is PI * 1^2 = PI. The expected ratio between the areas of the unit circle and the square therefore gives the formula above. Solving for PI gives the approximated numeric value:

(b) PI = 4 * (M / N)

Expression (b) follows directly from expression (a) by multiplying both sides by 4.
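As a minimal sketch of expression (b) on its own, the estimator can be written as a small function. The names PiEstimator and Estimate, and the fixed seed, are assumptions made for this illustration and are not part of the sample further down.

```csharp
using System;

static class PiEstimator
{
    // Estimates PI as 4 * (M / N), where M counts the random points
    // that fall inside the quarter circle of radius 1.
    public static double Estimate(int samples, int seed)
    {
        var rnd = new Random(seed); // fixed seed (an assumption) gives repeatable runs
        int inside = 0;

        for (int i = 0; i < samples; i++)
        {
            double x = rnd.NextDouble();
            double y = rnd.NextDouble();

            // Comparing the squared distance avoids a square root.
            if (x * x + y * y <= 1.0)
                inside++;
        }

        return 4.0 * inside / samples;
    }
}

class EstimatorDemo
{
    static void Main()
    {
        Console.WriteLine("Approximated Pi = {0:F8}", PiEstimator.Estimate(1000000, 42));
    }
}
```

With one million samples the estimate typically agrees with PI to two or three decimals; the error shrinks only with the square root of N, which is why the full sample uses tens of millions of iterations.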
Let's move on to the code sample. I have included a screen shot at the end. After several runs, the sequential version takes about 3.5 seconds on my eight-core system. Parallel.For takes about half that time, about 1.8 seconds. The last version, using Tasks and Task.WaitAll, is the quickest at about 1.7 seconds, about twice as fast as the sequential version. The demo used 80 million iterations (10 million per core).
Here is the code written in C#:

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace MonteCarloPiApproximation
{
    class Program
    {
        static int numberOfCores = Environment.ProcessorCount; 
        static int iterations = 10000000 * numberOfCores;

        static void Main(string[] args)
        {
            Console.WriteLine("Monte Carlo numeric simulation of PI");
            Console.WriteLine("Iteration limit: " + iterations);
            Console.WriteLine("Number of processor cores on system: " + Environment.ProcessorCount);
            var sw = new Stopwatch();
            sw.Start();

            Console.WriteLine("\nMONTE CARLO SIMULATION");

            MonteCarloPiApproximationSerialSimulation();
            sw.Stop();

            Console.WriteLine("Serial simulation: (ms)" + sw.ElapsedMilliseconds);
            Console.WriteLine();

            sw.Restart();

            Console.WriteLine("\nMONTE CARLO SIMULATION");
            MonteCarloPiApproximationParallellForSimulation(); 

            sw.Stop();

            Console.WriteLine("Parallel simulation using Parallel.For: (ms)" + sw.ElapsedMilliseconds);
            Console.WriteLine();

            sw.Restart();

            Console.WriteLine("\nMONTE CARLO SIMULATION");

            MonteCarloPiApproximationParallelTasksSimulation(); 

            // Stop the timer before printing, as in the two runs above.
            sw.Stop();

            Console.WriteLine("Parallel simulation using parallel Tasks: (ms)" + sw.ElapsedMilliseconds);
            Console.WriteLine();

            Console.WriteLine("Press Enter Key");



            Console.ReadKey(); 
        }

        private static void MonteCarloPiApproximationParallelTasksSimulation()
        {
            double piApproximation = 0;
            int inCircle = 0;

            int[] localCounters = new int[numberOfCores];
            Task[] tasks = new Task[numberOfCores];

            for (int i = 0; i < numberOfCores; i++)
            {
                int procIndex = i; //closure capture 
                tasks[procIndex] = Task.Factory.StartNew(() =>
                {
                    int localCounterInside = 0;

                    // Seed each task separately: Random instances created in quick succession
                    // with the default constructor can share a time-based seed on .NET Framework.
                    Random rnd = new Random(Guid.NewGuid().GetHashCode());

                    for (int j = 0; j < iterations / numberOfCores; j++)
                    {
                        // x and y are local to each task; sharing them across tasks is a data race.
                        double x = rnd.NextDouble();
                        double y = rnd.NextDouble();
                        if (Math.Sqrt(x * x + y * y) <= 1.0)
                            localCounterInside++;
                    } 
                    localCounters[procIndex] = localCounterInside;

                });               
            }

            Task.WaitAll(tasks);
            inCircle = localCounters.Sum(); 

            piApproximation = 4 * ((double)inCircle / (double)iterations);

            Console.WriteLine();
            Console.WriteLine("Approximated Pi = {0}", piApproximation.ToString("F8"));
           
        }      

        private static void MonteCarloPiApproximationParallellForSimulation()
        {
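            double piApproximation = 0;
            int inCircle = 0;
                   
            Parallel.For(0, numberOfCores, new ParallelOptions{ MaxDegreeOfParallelism = numberOfCores }, i =>
            {
              
                int localCounterInside = 0;

                // Seed each worker separately: Random instances created in quick succession
                // with the default constructor can share a time-based seed on .NET Framework.
                Random rnd = new Random(Guid.NewGuid().GetHashCode()); 

                for (int j = 0; j < iterations / numberOfCores; j++)
                {
                    // x and y are local to each worker; sharing them across workers is a data race.
                    double x = rnd.NextDouble();
                    double y = rnd.NextDouble();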
                    if (Math.Sqrt(x*x+y*y) <= 1.0)
                        localCounterInside++;                                                        
                }

                Interlocked.Add(ref inCircle, localCounterInside); 
                            
            }); 

            piApproximation = 4 * ((double)inCircle / (double)iterations);

            Console.WriteLine();
            Console.WriteLine("Approximated Pi = {0}", piApproximation.ToString("F8"));
            
        }

        private static void MonteCarloPiApproximationSerialSimulation()
        {
            double piApproximation = 0;
            int total = 0;
            int inCircle = 0; 
            double x,y = 0;
            Random rnd = new Random(); 

            while (total < iterations)
            {
                x = rnd.NextDouble(); 
                y = rnd.NextDouble();

                if ((Math.Sqrt(x*x+y*y) <= 1.0))
                    inCircle++;

                total++;                
                piApproximation =  4 * ((double)inCircle / (double)total); 
            } //while 


            Console.WriteLine();
            Console.WriteLine("Approximated Pi = {0}", piApproximation.ToString("F8"));

        }




    }
}


Wednesday 15 July 2015

Data block types in Task Parallel Library

The following code sample shows data block types in Task Parallel Library (TPL). The code is written in C# and can be run in a simple Console Application. Dataflow in TPL makes it possible to build actor-based programs and orchestrate coarse-grained dataflow and pipelining tasks, maintaining robustness and supporting concurrency-enabled applications. This makes it easier to construct high-performance, low-latency systems.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;

namespace TplDataFlowSimpleTests
{
    class Program
    {

        static void Main(string[] args)
        {
            ActionBlockPostingAndFaulting();

            //BufferBlockPostingAndReceiving();

            //BroadCastPostAndMultipleReceive();

            //WriteOnceBlockParallellPostsAndSingleReceive();

            //ActionBlockPostingCompleting();

            //TransformBlockPostingAndProducingOutput(); 

            //TransformManyBlockPostingsAndReceive(); 

            //BatchBlockSeveralBatches();

            //JoinBlockAritmeticCombinations(); 

            //BatchedJoinBlockPropagatingValuesAndException();

            Console.WriteLine("Press the any key to continue ..");
            Console.ReadKey();

        }

        private static void BatchedJoinBlockPropagatingValuesAndException()
        {
            Func<int, int> doWork = n =>
            {
                if (n < 0)
                    throw new ArgumentOutOfRangeException();
                return n;
            };

            var batchedJoinBlock = new BatchedJoinBlock<int, Exception>(7);

            foreach (int i in new int[] {5, 6, -7, -22, 13, 55, 0})
            {
                try
                {
                    batchedJoinBlock.Target1.Post(doWork(i));
                }
                catch (ArgumentOutOfRangeException e)
                {
                    batchedJoinBlock.Target2.Post(e);
                }
            }

            var results = batchedJoinBlock.Receive();

            foreach (int n in results.Item1)
            {
                Console.WriteLine(n);
            }

            foreach (Exception e in results.Item2)
            {
                Console.WriteLine(e.Message);
            }
        }

        private static void JoinBlockAritmeticCombinations()
        {
            var joinBlock = new JoinBlock<int, int, char>(new GroupingDataflowBlockOptions{ Greedy = true});

            joinBlock.Target1.Post(3);
            joinBlock.Target1.Post(6);

            joinBlock.Target2.Post(5);
            joinBlock.Target2.Post(4);

            joinBlock.Target3.Post('+');
            joinBlock.Target3.Post('-');

            for (int i = 0; i < 2; i++)
            {
                var data = joinBlock.Receive();

                switch (data.Item3)
                {
                    case '+':
                        Console.WriteLine("{0} + {1} = {2}", data.Item1, data.Item2, data.Item1 + data.Item2);
                        break;
                    case '-':
                        Console.WriteLine("{0} - {1} = {2}", data.Item1, data.Item2, data.Item1 - data.Item2);
                        break;
                    default:
                        Console.WriteLine("Unknown operator '{0}'.", data.Item3);
                        break;
                } //switch 
            } //for 

        }

        private static void BatchBlockSeveralBatches()
        {
            var batchBlock = new BatchBlock<int>(10);

            for (int i = 0; i < 13; i++)
            {
                batchBlock.Post(i);
            }

            batchBlock.Complete();


            Console.WriteLine("The elements of the first batch are: [{0}] ", string.Join(",", batchBlock.Receive()));
            Console.WriteLine("The elements of the second batch are: [{0}]", string.Join(",", batchBlock.Receive()));
        }

        private static void TransformManyBlockPostingsAndReceive()
        {
            var transformManyBlock = new TransformManyBlock<string, char>(s => s.ToCharArray());

            transformManyBlock.Post("Hello");
            transformManyBlock.Post("World");

            for (int i = 0; i < ("Hello" + "World").Length; i++)
            {
                Console.WriteLine(transformManyBlock.Receive());
            }
        }

        private static void TransformBlockPostingAndProducingOutput()
        {
            var transformBlock = new TransformBlock<int, double>(n => Math.Sqrt(n));

            transformBlock.Post(10);
            transformBlock.Post(20);
            transformBlock.Post(30);

            for (int i = 0; i < 3; i++)
            {
                Console.WriteLine(transformBlock.Receive());
            }
        }

        private static void ActionBlockPostingCompleting()
        {
            var cts = new CancellationTokenSource();
            var actionBlock = new ActionBlock<int>(n =>
            {
                Console.WriteLine(n);
                Thread.Sleep(1000);
                if (n > 50)
                    cts.Cancel();
            }, new ExecutionDataflowBlockOptions{ MaxDegreeOfParallelism = 2, MaxMessagesPerTask = 1, CancellationToken = cts.Token });

            for (int i = 0; i < 10; i++)
            {
                actionBlock.Post(i*10); 
         
            }

            actionBlock.Complete();
            try
            {
                actionBlock.Completion.Wait(cts.Token);
            }
            catch (OperationCanceledException ex)
            {
                Console.WriteLine(ex.Message);
            }
        }

        private static void WriteOnceBlockParallellPostsAndSingleReceive()
        {
            var writeOnceBlock = new WriteOnceBlock<string>(null);

            Parallel.Invoke(
                () => writeOnceBlock.Post("Message 1"),
                () => writeOnceBlock.Post("Message 2"),
                () => writeOnceBlock.Post("Message 3"));

            Console.WriteLine(writeOnceBlock.Receive());
        }

        private static void BroadCastPostAndMultipleReceive()
        {
            var broadcastBlock = new BroadcastBlock<double>(null);

            broadcastBlock.Post(Math.PI);

            for (int i = 0; i < 3; i++)
            {
                Console.WriteLine(broadcastBlock.Receive());
            }
        }

        private static void BufferBlockPostingAndReceiving()
        {
            var bufferBlock = new BufferBlock<int>();

            for (int i = 0; i < 3; i++)
            {
                bufferBlock.Post(i);
            }

            for (int i = 0; i < 4; i++)
            {
                try
                {
                    Console.WriteLine(bufferBlock.Receive(new TimeSpan(0, 0, 2)));
                }
                catch (TimeoutException tie)
                {
                    Console.WriteLine("Exception of type: {0} with message: {1}", tie.GetType().Name, tie.Message);
                }
            }
        }

        private static void ActionBlockPostingAndFaulting()
        {
            var throwIfNegative = new ActionBlock<int>(n =>
            {
                Console.WriteLine("n = {0}", n);
                if (n < 0)
                    throw new ArgumentOutOfRangeException();
            });

            throwIfNegative.Completion.ContinueWith(
                task => { Console.WriteLine("The status of the completion task is '{0}'", task.Status); });

            throwIfNegative.Post(0);
            throwIfNegative.Post(-1);
            throwIfNegative.Post(1);
            throwIfNegative.Post(2);

            throwIfNegative.Complete();

            try
            {
                throwIfNegative.Completion.Wait();
            }
            catch (AggregateException ae)
            {
                ae.Handle(e =>
                {
                    Console.WriteLine("Encountered {0}: {1}", e.GetType().Name, e.Message);
                    return true;
                });
            }

          
        }
    }
}
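The samples above exercise each block type in isolation. As a minimal sketch of the pipelining mentioned in the introduction, two blocks can be linked with LinkTo so that output from the first flows into the second. The names PipelineSketch and RunPipeline are assumptions made for this illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks.Dataflow;

static class PipelineSketch
{
    // Squares each input in a TransformBlock and collects the results in an ActionBlock.
    public static List<int> RunPipeline(IEnumerable<int> inputs)
    {
        var results = new List<int>();

        var square = new TransformBlock<int, int>(n => n * n);

        // The default MaxDegreeOfParallelism is 1, so the list is only
        // touched by one message at a time.
        var collect = new ActionBlock<int>(n => results.Add(n));

        // PropagateCompletion forwards Complete() (and faults) down the pipeline.
        square.LinkTo(collect, new DataflowLinkOptions { PropagateCompletion = true });

        foreach (int n in inputs)
            square.Post(n);

        square.Complete();
        collect.Completion.Wait();

        return results;
    }

    static void Main()
    {
        Console.WriteLine(string.Join(",", RunPipeline(new[] { 1, 2, 3 })));
    }
}
```

Without PropagateCompletion, the ActionBlock would never complete and the Wait() call would hang, so that option is worth setting whenever the last block's Completion task is awaited.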

Monday 13 July 2015

Producer-consumer scenario with BlockingCollection of Task Parallel Library

The BlockingCollection of Task Parallel Library (TPL) supports producer-consumer scenarios well. This is a pattern where you want to start processing items as soon as they are available. BlockingCollection makes this easy, as long as you follow the convention. In this article, a simple scenario with five producers and one consumer is presented. We collect the producer tasks and wait on them using the Task.WaitAll(Task[]) method. This creates a barrier in our main method, after which we call the CompleteAdding() method of BlockingCollection to signal that we have no more items to add. The consumer, which here is a single Task, uses standard foreach iteration and calls the method GetConsumingEnumerable() on the BlockingCollection to get the collection to iterate over. Note that the consumer stays inside the foreach loop, blocking whenever the collection is empty, until CompleteAdding() has been called on the BlockingCollection and no more items are available.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace TestBlockingCollectionSecond
{

    public class Node
    {

        public int ManagedThreadId { get; set; }

        public int Value { get; set; }

        public override string ToString()
        {
            return string.Format("Inside ManagedThreadId {0}, Value: {1}", ManagedThreadId, Value);
        }

    }

    class Program
    {
        static void Main(string[] args)
        {

            BlockingCollection<Node> bc = new BlockingCollection<Node>();

            Console.WriteLine("Producer-consumer scenario BlockingCollection.\nFive producers, one consumer scenario");

            var producers = new List<Task>();

            for (int i = 0; i < 5; i++)
            {

                var producer = Task.Factory.StartNew(() =>
                {
                    // Random is not thread-safe, so each producer gets its own instance.
                    var rnd = new Random(Guid.NewGuid().GetHashCode());

                    for (int j = 0; j < 5; j++)
                    {
                        Thread.Sleep(rnd.Next(100,300));
                        bc.Add(new Node { ManagedThreadId = Thread.CurrentThread.ManagedThreadId, Value = rnd.Next(1, 100) });
                    }
                });

                producers.Add(producer);

            }

            var consumer = Task.Factory.StartNew(() =>
            {
                foreach (Node item in bc.GetConsumingEnumerable())
                {
                    Console.WriteLine(item);
                }

                Console.WriteLine("Consumer all done!");             
            });

            Task.WaitAll(producers.ToArray());

            bc.CompleteAdding();

            // Wait for the consumer to drain the remaining items before prompting.
            consumer.Wait();
       

            Console.WriteLine("Hit any key to exit program");
            Console.ReadKey();         


        }
    }
}


If you are not sure when CompleteAdding() will be signalled, you can keep taking items from the BlockingCollection with Take(), using a single consumer or several. Remember to catch InvalidOperationException, which Take() throws when no more items are available after CompleteAdding() has been called on the BlockingCollection. Note that the while loop below is wrapped in a consumer Task that is once again started before the Task.WaitAll call on the producers, so that consuming starts right away. The exit condition is the flag bc.IsCompleted, which is true once CompleteAdding() has been called and the collection is empty; checking bc.IsAddingCompleted alone could leave items that are still queued unprocessed. The benefit of using GetConsumingEnumerable() instead is that the consumer does not have to deal with either the exception or the completion flag.



            var consumer = Task.Factory.StartNew(() =>
            {
                try
                {
                    while (!bc.IsCompleted)
                    {
                        Console.WriteLine("BlockingCollection element: " + bc.Take());
                    }
                }
                catch (InvalidOperationException)
                {
                    // Take() throws if CompleteAdding() was called while we were waiting on an empty collection.
                }

                Console.WriteLine("Consumer all done!");
            });
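A third option, shown here as a minimal sketch, is to poll with TryTake and a timeout, which returns false instead of throwing when the collection is empty or completed. The names TryTakeSketch and Drain and the 100 ms timeout are assumptions made for this illustration.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading.Tasks;

static class TryTakeSketch
{
    // Takes items until the collection is marked complete and is empty.
    public static List<int> Drain(BlockingCollection<int> bc)
    {
        var seen = new List<int>();

        while (!bc.IsCompleted)
        {
            int item;

            // Waits up to 100 ms for an item; returns false on timeout.
            if (bc.TryTake(out item, 100))
                seen.Add(item);
        }

        return seen;
    }

    static void Main()
    {
        var bc = new BlockingCollection<int>();
        Task<List<int>> consumer = Task.Factory.StartNew(() => Drain(bc));

        for (int i = 0; i < 5; i++)
            bc.Add(i);

        bc.CompleteAdding();

        Console.WriteLine("Consumed {0} items", consumer.Result.Count);
    }
}
```

The timeout trades a little latency for a loop that can also check other conditions, such as a cancellation flag, between attempts.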

Wednesday 8 July 2015

Logging the SQL of Entity Framework exceptions in EF 6

If you use EF 6, it is possible to add logging functionality that reveals why an exception in the data layer of your app or system occurred, and in addition creates runnable SQL that you can try out in your testing/production environment inside a transaction that is rolled back, for a quicker diagnosis-cause-fix cycle! First off, create a class that implements the interface IDbCommandInterceptor in the System.Data.Entity.Infrastructure.Interception namespace. This class is then registered from your ObjectContext / DbContext class (this is usually a partial class that you can extend) using the DbInterception.Add method. I add this class in the static constructor of my factory class inside a try-catch block. The important part is that you call the DbInterception.Add method and instantiate the class you create. Let's consider a code example of this. I am only focusing on logging exceptions; other kinds of interception can of course be performed. Here is the sample class for logging exceptions. I have replaced the namespaces of my own system with the more generic "Acme":

using System;
using System.Data;
using System.Data.Common;
using System.Data.Entity.Infrastructure.Interception;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using Acme.Common.Log;



namespace Acme.Data.EntityFramework
{
    
    /// <summary>
    /// Intercepts exceptions that are raised by the database when running an operation and writes them to the event log 
    /// </summary>
    public class AcmeDbCommandInterceptor : IDbCommandInterceptor
    {

        public void NonQueryExecuted(DbCommand command, DbCommandInterceptionContext<int> interceptionContext)
        {
            LogIfError(command, interceptionContext);
        }

        public void NonQueryExecuting(DbCommand command, DbCommandInterceptionContext<int> interceptionContext)
        {
            LogIfError(command, interceptionContext);          
        }

        public void ReaderExecuted(DbCommand command, DbCommandInterceptionContext<DbDataReader> interceptionContext)
        {
            LogIfError(command, interceptionContext);          
        }

        public void ReaderExecuting(DbCommand command, DbCommandInterceptionContext<DbDataReader> interceptionContext)
        {
            LogIfError(command, interceptionContext);
        }

        public void ScalarExecuted(DbCommand command, DbCommandInterceptionContext<object> interceptionContext)
        {
            LogIfError(command, interceptionContext);            
        }

        public void ScalarExecuting(DbCommand command, DbCommandInterceptionContext<object> interceptionContext)
        {
            LogIfError(command, interceptionContext);           
        }

        private void LogIfError<TResult>(DbCommand command, DbCommandInterceptionContext<TResult> interceptionContext)
        {
            try
            {
                if (interceptionContext != null && interceptionContext.Exception != null)
                {
                    try
                    {
                        LogInterpolatedEfQueryString(command, interceptionContext);
                    }
                    catch (Exception err)
                    {
                        // Fall back to the raw command text and parameters, logged once.
                        LogRawEfQueryString(command, interceptionContext, err);
                    }
                   
                }
            }
            catch (Exception err)
            {
                Debug.WriteLine(err.Message);
            }
        }

        /// <summary>
        /// Logs the raw EF query string 
        /// </summary>
        /// <typeparam name="TResult"></typeparam>
        /// <param name="command"></param>
        /// <param name="interceptionContext"></param>
        /// <param name="err"></param>
        private static void LogRawEfQueryString<TResult>(DbCommand command, DbCommandInterceptionContext<TResult> interceptionContext,
            Exception err)
        {
            if (err != null)
                Debug.WriteLine(err.Message);

            string queryParameters = LogEfQueryParameters(command);

            EventLogProvider.Log(
                string.Format(
                    "Acme serverside DB operation failed: Exception: {0}. Parameters involved in EF query: {1}. SQL involved in EF Query: {2}",
                    Environment.NewLine + interceptionContext.Exception.Message,
                    Environment.NewLine + queryParameters,
                    Environment.NewLine + command.CommandText
                    ), EventLogProviderEnum.Warning);
        }

        /// <summary>
        /// Return a string with the list of EF query parameters 
        /// </summary>
        /// <param name="command"></param>
        /// <returns></returns>
        private static string LogEfQueryParameters(DbCommand command)
        {
            var sb = new StringBuilder(); 
            for (int i = 0; i < command.Parameters.Count; i++)
            {
                if (command.Parameters[i].Value != null)
                    sb.AppendLine(string.Format(@"Query param {0}: {1}", i, command.Parameters[i].Value));
            }
            return sb.ToString();
        }

        private static DbType[] QuoteRequiringTypes
        {
            get
            {
                return new[]
                {
                        DbType.AnsiString, DbType.AnsiStringFixedLength, DbType.String,
                            DbType.StringFixedLength, DbType.Date, DbType.DateTime,
                            DbType.DateTime2, DbType.Guid
                };
            }
        }


        private static void LogInterpolatedEfQueryString<TResult>(DbCommand command,
            DbCommandInterceptionContext<TResult> interceptionContext)
        {
            var paramRegex = new Regex("@\\d+");
            string interpolatedSqlString = paramRegex.Replace(command.CommandText,
                m => GetInterpolatedString(command, m));

            EventLogProvider.Log(string.Format(
                "Acme serverside DB operation failed: Exception: {0}. SQL involved in EF Query: {1}",
                Environment.NewLine + interceptionContext.Exception.Message,
                Environment.NewLine + interpolatedSqlString),
                EventLogProviderEnum.Warning);

        }

        private static string GetInterpolatedString(DbCommand command, Match m)
        {
            try
            {
                int matchIndex;
                if (string.IsNullOrEmpty(m.Value))
                    return m.Value;
                int.TryParse(m.Value.Replace("@", ""), out matchIndex);
                    //Entity Framework will usually build parameterized queries with @0, @1, @2 and so on .. 
                if (matchIndex < 0 || matchIndex >= command.Parameters.Count)
                    return m.Value;

                //Ok matchIndex from here 
                DbParameter dbParameter = command.Parameters[matchIndex];
                var dbParameterValue = dbParameter.Value;
                if (dbParameterValue == null)
                    return m.Value;

                try
                {
                    return GetAdjustedDbParameterValue(dbParameter, dbParameterValue);
                }
                catch (Exception err)
                {
                    Debug.WriteLine(err.Message);
                }
            }
            catch (Exception err)
            {
                Debug.WriteLine(err.Message);
            }

            return m.Value;
        }

        /// <summary>
        /// Adjusts the DbParameter value where needed: quotes string-like, date and Guid values, and converts booleans to a BIT value 
        /// </summary>
        /// <param name="dbParameter"></param>
        /// <param name="dbParameterValue"></param>
        /// <returns></returns>
        private static string GetAdjustedDbParameterValue(DbParameter dbParameter, object dbParameterValue)
        {
            if (QuoteRequiringTypes.Contains(dbParameter.DbType))
                return string.Format("'{0}'", dbParameterValue.ToString().Replace("'", "''")); //Quote the value and escape embedded single quotes 

            if (dbParameter.DbType == DbType.Boolean)
            {
                bool dbParameterBitValue;
                bool.TryParse(dbParameterValue.ToString(), out dbParameterBitValue);
                return dbParameterBitValue ? "1" : "0"; //BIT
            }

            return dbParameterValue.ToString(); //Default case (not a quoted value and not a bit value)
        }
    }
}


The code above uses a class EventLogProvider that will record to the Event Log of the system running the Entity Framework code in the data layer, usually the application server of your system. Here is the relevant code for logging to the Eventlog:

using System.Diagnostics;
using System;
using Acme.Common.Security;

namespace Acme.Common.Log
{
    public enum EventLogProviderEnum
    {
        Warning, Error
    }

    public static class EventLogProvider
    {
        private static string _applicationSource = "Acme";
        private static string _applicationEventLogName = "Application";

        public static void Log(Exception exception, EventLogProviderEnum logEvent)
        {
            Log(ErrorUtil.ConstructErrorMessage(exception), logEvent);
        }

        public static void Log(string logMessage, EventLogProviderEnum logEvent)
        {
            try
            {
                if (!EventLog.SourceExists(_applicationSource))
                    EventLog.CreateEventSource(_applicationSource, _applicationEventLogName);

                EventLog.WriteEntry(_applicationSource, logMessage, GetEventLogEntryType(logEvent));
            }
            catch { } // If the event log is unavailable, don't crash.
        }

        private static EventLogEntryType GetEventLogEntryType(EventLogProviderEnum logEvent)
        {
            switch(logEvent)
            {
                case EventLogProviderEnum.Error:
                    return EventLogEntryType.Error;
                case EventLogProviderEnum.Warning:
                default:
                    return EventLogEntryType.Warning;
            }
        }
    }
}


It is also necessary to add an instance of the db interception class in your ObjectContext or DbContext class as noted. Example:

try
{
    DbInterception.Add(new AcmeDbCommandInterceptor());
}
catch (Exception err)
{
    //Log error here (consider using the EventLogProvider above for example)
}

The interpolated string will often be the most interesting one when EF queries fail. Entity Framework (EF) uses parameterized queries to prevent SQL injection attacks. To get SQL you can actually run, you will usually interpolate the EF query CommandText with the parameter values, which are named @0, @1, @2 and so on. I use a Regex here to search for these. Note that my code uses a lot of try-catch in case something goes wrong. You also do NOT want to run any heavy code here, as the DbInterception will run on ANY query. I only do further processing IF an exception has occurred, to avoid a performance drain on the system. The code first tries to log the interpolated EF query string, which I can run in my Production or Test environment, usually inside a BEGIN TRAN .. ROLLBACK statement, just to see why the database call failed. If the interpolation itself fails, the code logs the raw EF query instead: the CommandText and the command parameters, without the interpolation technique. I have also made some adjustments, adding single quotes around strings and treating booleans as the value 0 or 1 (BIT). This code is new and there might be some additional adjustments needed. The bottom line is that it is important to LOG the EF query SQL to INFER the real REASON why the SQL query FAILED, i.e. a quicker DIAGNOSE-INFER-FIX cycle leading to more success on your projects, if you use .NET and Entity Framework (version 6 or newer)!
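The interpolation step can be tried out without a database. The following is a minimal sketch of the same Regex technique applied to a plain array of values instead of a DbCommand. The names SqlInterpolationSketch and Interpolate, and the Patients table in the sample query, are hypothetical and only serve the illustration.

```csharp
using System;
using System.Globalization;
using System.Text.RegularExpressions;

static class SqlInterpolationSketch
{
    // Replaces @0, @1, ... in the command text with the value at that position.
    public static string Interpolate(string commandText, object[] parameterValues)
    {
        return Regex.Replace(commandText, @"@(\d+)", m =>
        {
            int index;
            if (!int.TryParse(m.Groups[1].Value, out index)
                || index >= parameterValues.Length
                || parameterValues[index] == null)
                return m.Value; // leave unknown placeholders untouched

            object value = parameterValues[index];

            if (value is bool)
                return (bool)value ? "1" : "0"; // BIT

            if (value is string)
                return "'" + ((string)value).Replace("'", "''") + "'"; // quote and escape

            return Convert.ToString(value, CultureInfo.InvariantCulture);
        });
    }

    static void Main()
    {
        string sql = "SELECT * FROM Patients WHERE Name = @0 AND IsActive = @1 AND Age > @2";
        Console.WriteLine(Interpolate(sql, new object[] { "O'Brien", true, 42 }));
        // prints: SELECT * FROM Patients WHERE Name = 'O''Brien' AND IsActive = 1 AND Age > 42
    }
}
```

The result is meant for diagnosis in a log only; the interpolated text should never be sent back to the database in place of the parameterized command.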

Tuesday 26 May 2015

Recording network traffic between clients and WCF services to a database

Developers working with WCF services might have come across the excellent tool Fiddler, which is an HTTP web debugging proxy. In some production environments, it could be useful to log network traffic, in the form of requests and responses, to a database. There are several possible approaches to this; one is to implement a WCF message inspector, which will also act as a WCF service behavior.

RecordingMessageBehavior

The following class implements the two interfaces IDispatchMessageInspector and IServiceBehavior. Please note that the code contains some domain-specific production code, including some claims-specific handling that is not required for recording network traffic, although your code base may use something similar.
In other words, you will have to adjust the code below to make it work with your code base, but the core structure of the code should be of general interest for WCF developers.


using System;
using System.Configuration;
using System.ServiceModel;
using System.ServiceModel.Channels;
using System.ServiceModel.Description;
using System.ServiceModel.Dispatcher;

using Microsoft.IdentityModel.Claims;
using System.Diagnostics;

namespace Hemit.OpPlan.Service.Host
{
    public class RecordingMessageBehavior : IDispatchMessageInspector, IServiceBehavior
    {

        private readonly bool _useTrafficRecording;

        private static DateTime _lastClearTime = DateTime.Now; 

        private readonly ITrafficLogManager _trafficLogManager;


        public RecordingMessageBehavior()
        {
            _trafficLogManager = new TrafficLogManager(); 

            try
            {
                if (ConfigurationManager.AppSettings[Constants.UseTrafficRecording] != null)
                {
                    _useTrafficRecording = bool.Parse(ConfigurationManager.AppSettings[Constants.UseTrafficRecording]); 
                }
            }
            catch (Exception err)
            {
                Debug.WriteLine(err.Message);
            }
        }


        #region IDispatchMessageInspector Members

        public object AfterReceiveRequest(ref Message request, IClientChannel channel, InstanceContext instanceContext)
        {
            MessageBuffer buffer = request.CreateBufferedCopy(Int32.MaxValue);
            request = buffer.CreateMessage();
            Message messageToWorkWith = buffer.CreateMessage(); 

            if (IsTrafficRecordingDeactivated())
                return request;
            InspectMessage(messageToWorkWith, true);
            return request;
        }

        public void BeforeSendReply(ref Message reply, object correlationState)
        {
            MessageBuffer buffer = reply.CreateBufferedCopy(Int32.MaxValue);
            reply = buffer.CreateMessage();
            Message messageToWorkWith = buffer.CreateMessage(); 

            if (IsTrafficRecordingDeactivated())
                return;

            InspectMessage(messageToWorkWith, false);
        }

        private bool IsTrafficRecordingDeactivated()
        {
            return !_useTrafficRecording || !IsCurrentUserLoggedIn();
        }

        #endregion

        private void InspectMessage(Message message, bool isRequest)
        {
            try
            {
                string serializedContent = message.ToString();
                string remoteIp = null;

                if (isRequest)
                {
                    try
                    {
                        RemoteEndpointMessageProperty remoteEndpoint = message.Properties[RemoteEndpointMessageProperty.Name] as RemoteEndpointMessageProperty;
                        if (remoteEndpoint != null) remoteIp = remoteEndpoint.Address + ":" + remoteEndpoint.Port;
                    }
                    catch (Exception err)
                    {
                        Debug.WriteLine(err.Message);
                    }
                }

                var trafficLogData = new TrafficLogItemDataContract
                {
                    IsRequest = isRequest,
                    RequestIP = remoteIp,
                    Recorded = DateTime.Now,
                    SoapMessage = serializedContent,
                    ServiceMethod = GetSoapActionPart(message.Headers.Action, 1),
                    ServiceName = GetSoapActionPart(message.Headers.Action, 2)
                };

                _trafficLogManager.InsertTrafficLogItem(trafficLogData, ResolveCurrentClaimsIdentity());

                if (DateTime.Now.Subtract(_lastClearTime).TotalHours > 1)
                {
                    _trafficLogManager.ClearOldTrafficLog(ResolveCurrentClaimsIdentity()); //at most once an hour, clear traffic logs older than the retention period
                    _lastClearTime = DateTime.Now; 
                }
            }
            catch (Exception ex)
            {
                InterfaceBinding.GetInstance<ILog>().WriteError(ex.Message, ex);

            }
        }

        private string GetSoapActionPart(string soapAction, int rightOffset)
        {
            if (string.IsNullOrEmpty(soapAction) || !soapAction.Contains("/"))
                return soapAction;

            string[] soapParts = soapAction.Split('/');
            int soapPartCount = soapParts.Length;
            if (rightOffset >= soapPartCount)
                return soapAction;
            else
                return soapParts[soapPartCount - rightOffset]; 
        }

        private IClaimsIdentity ResolveCurrentClaimsIdentity()
        {
            if (ServiceSecurityContext.Current != null && ServiceSecurityContext.Current.AuthorizationContext != null)
            {
                var claimsIdentity = ServiceSecurityContext.Current.PrimaryIdentity as IClaimsIdentity; //Retrieval of current claims identity from WCF ServiceSecurityContext
                return claimsIdentity;
            }
            return null;
        }

        private bool IsCurrentUserLoggedIn()
        {
            IClaimsIdentity claimsIdentity = ResolveCurrentClaimsIdentity();
            if (claimsIdentity == null || claimsIdentity.Claims == null || claimsIdentity.Claims.Count == 0 
                || claimsIdentity.FindClaim(WellKnownClaims.AuthorizedForOrganizationalUnit) == null)
                return false;
            return true;
        }

        #region IServiceBehavior Members

        public void AddBindingParameters(ServiceDescription serviceDescription, ServiceHostBase serviceHostBase, System.Collections.ObjectModel.Collection<ServiceEndpoint> endpoints, BindingParameterCollection bindingParameters)
        {
        }

        public void ApplyDispatchBehavior(ServiceDescription serviceDescription, ServiceHostBase serviceHostBase)
        {
            foreach (var channelDispatcherBase in serviceHostBase.ChannelDispatchers)
            {
                var chdisp = channelDispatcherBase as ChannelDispatcher;
                if (chdisp == null)
                    continue; 

                foreach (var endpoint in chdisp.Endpoints)
                {
                    endpoint.DispatchRuntime.MessageInspectors.Add(new RecordingMessageBehavior()); 
                }
            }
        }

        public void Validate(ServiceDescription serviceDescription, ServiceHostBase serviceHostBase)
        {
        }

        #endregion
    }
}
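
To make the ServiceName and ServiceMethod columns a bit more concrete, here is a minimal stand-alone sketch of how the SOAP action is split. It follows the same idea as GetSoapActionPart above, but the class name SoapActionParts and the sample action URI are assumptions of mine for illustration, and I have added a guard against offsets below 1 that the original method does not have.

```csharp
using System;

public static class SoapActionParts
{
    // Returns one segment of a '/'-separated SOAP action, counted from the right:
    // rightOffset 1 is the last segment (the method), 2 is the one before it (the service).
    public static string GetPart(string soapAction, int rightOffset)
    {
        if (string.IsNullOrEmpty(soapAction) || !soapAction.Contains("/"))
            return soapAction;

        string[] parts = soapAction.Split('/');

        // Out-of-range offsets fall back to the whole action string
        if (rightOffset < 1 || rightOffset >= parts.Length)
            return soapAction;

        return parts[parts.Length - rightOffset];
    }
}
```

For a hypothetical action such as http://tempuri.org/IOperationService/GetOperationItems, offset 1 gives the method name and offset 2 gives the service contract name.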

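To turn this WCF recording message behavior on or off, set the app setting UseTrafficRecording to true or false in web.config. The constructor reads it through ConfigurationManager.AppSettings, so (assuming the constant Constants.UseTrafficRecording holds the key name UseTrafficRecording) this is an ordinary add element with key="UseTrafficRecording" and value="true" in the appSettings section. If the setting is missing, recording stays off.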



In the code above, the first thing the two methods AfterReceiveRequest and BeforeSendReply do is to clone the message. This is because WCF messages are just like messages in MSMQ message queues: they can be consumed only once. We therefore create a buffered copy and then create two new messages from the buffer, one that is handed back to WCF through the ref parameter and one that we read for logging.

        public object AfterReceiveRequest(ref Message request, IClientChannel channel, InstanceContext instanceContext)
        {

            MessageBuffer buffer = request.CreateBufferedCopy(Int32.MaxValue);
            request = buffer.CreateMessage();
            Message messageToWorkWith = buffer.CreateMessage(); 

            if (IsTrafficRecordingDeactivated())
                return request;
            InspectMessage(messageToWorkWith, true);
            return request;
        }

        public void BeforeSendReply(ref Message reply, object correlationState)
        {
 
            MessageBuffer buffer = reply.CreateBufferedCopy(Int32.MaxValue);
            reply = buffer.CreateMessage();
            Message messageToWorkWith = buffer.CreateMessage(); 


            if (IsTrafficRecordingDeactivated())
                return;

            InspectMessage(messageToWorkWith, false);
        }


Retrieving the serialized SOAP XML contents of the WCF message is easy: just call the ToString() method on the Message object. In addition, the code above requires some data access layer logic that persists the data to a database. Using Entity Framework and SQL Server here, logging is performed to the table TrafficLog. Please note that the code below retains at least 24 hours of network traffic: the cleanup deletes rows recorded at or before midnight at the start of the previous day, so between one and two days of traffic are kept. In a production environment the table can quickly grow into gigabytes (GB) if the services exchange many or large requests and responses with the clients. Usually you only want to record traffic when security demands it, or for error tracking and diagnostics.






using System;
using System.Collections.Generic;
using System.Linq;

using Microsoft.IdentityModel.Claims;


namespace Hemit.OpPlan.Service.Implementation
{

    public class TrafficLogManager : ITrafficLogManager
    {

        public List<TrafficLogItemDataContract> GetTrafficLog(DateTime startWindow, DateTime endWindow, IClaimsIdentity claimsIdentity)
        {
            using (var dbContext = ObjectContextManager.DbContextFromClaimsPrincipal(claimsIdentity))
            {
                var trafficLogs = dbContext.TrafficLogs.Where(tl => tl.Recorded >= startWindow && tl.Recorded <= endWindow).ToList();
                //List<T>.ForEach returns void, so project with Select after the query has been materialized
                return trafficLogs.Select(CreateTrafficLogDataContract).ToList();
            }
        }

        public void ClearOldTrafficLog(IClaimsIdentity claimsIdentity)
        {
            AggregateExceptionExtensions.CallActionLogAggregateException(() =>
            {
                System.Threading.ThreadPool.QueueUserWorkItem((object state) =>
                {
                    try
                    {
                        //Run on new thread 
                        using (var dbContext = ObjectContextManager.DbContextFromClaimsPrincipal(claimsIdentity))
                        {
                            //Use your own DbContext or ObjectContext here (EF)
                            DateTime obsoleteLogDate = DateTime.Now.Date.AddDays(-1);

                            var trafficLogs = dbContext.TrafficLogs.Where(tl => tl.Recorded <= obsoleteLogDate).ToList();
                            foreach (var tl in trafficLogs)
                            {
                                dbContext.TrafficLogs.DeleteObject(tl);
                            }
                            dbContext.SaveChanges();
                        }
                    }
                    catch (Exception err)
                    {
                        InterfaceBinding.GetInstance<ILog>().WriteError(err.Message); 
                    }
                });
            });
        }

        public void InsertTrafficLogItem(TrafficLogItemDataContract trafficLogItemdata, IClaimsIdentity claimsIdentity)
        {

            //string defaultConnectionString =  ObjectContextManager.GetConnectionString();

            using (var dbContext = ObjectContextManager.DbContextFromClaimsPrincipal(claimsIdentity))
            {
                var trafficLog = new TrafficLog();
                MapTrafficLogFromDataContract(trafficLogItemdata, trafficLog);
                dbContext.TrafficLogs.AddObject(trafficLog);
                dbContext.SaveChanges(); 
            }
        }

        private static void MapTrafficLogFromDataContract(TrafficLogItemDataContract trafficLogItem, TrafficLog trafficLog)
        {
            trafficLog.Recorded = trafficLogItem.Recorded;
            trafficLog.RequestIP = trafficLogItem.RequestIP;
            trafficLog.ServiceName = trafficLogItem.ServiceName;
            trafficLog.ServiceMethod = trafficLogItem.ServiceMethod;
            trafficLog.SoapMessage = trafficLogItem.SoapMessage;
            trafficLog.TrafficLogId = trafficLogItem.TrafficLogId;
            trafficLog.IsRequest = trafficLogItem.IsRequest;
        }

        private static TrafficLogItemDataContract CreateTrafficLogDataContract(TrafficLog trafficLog)
        {
            return new TrafficLogItemDataContract
            {
                IsRequest = trafficLog.IsRequest,
                Recorded = trafficLog.Recorded,
                RequestIP = trafficLog.RequestIP,
                ServiceMethod = trafficLog.ServiceMethod,
                ServiceName = trafficLog.ServiceName,
                SoapMessage = trafficLog.SoapMessage,
                TrafficLogId = trafficLog.TrafficLogId
            };
        }

    }    
   
}
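
The retention period follows from the cutoff expression DateTime.Now.Date.AddDays(-1) in ClearOldTrafficLog above. Here is a small sketch that isolates that expression so it can be reasoned about; the class and method names (TrafficLogRetention, GetCutoff, IsObsolete) are hypothetical and only serve the illustration.

```csharp
using System;

public static class TrafficLogRetention
{
    // Same expression as in ClearOldTrafficLog: midnight at the start of yesterday.
    public static DateTime GetCutoff(DateTime now)
    {
        return now.Date.AddDays(-1);
    }

    // A row is deleted when it was recorded at or before the cutoff.
    public static bool IsObsolete(DateTime recorded, DateTime now)
    {
        return recorded <= GetCutoff(now);
    }
}
```

Because the cutoff is aligned to midnight, a row recorded yesterday morning survives a cleanup that runs this afternoon even though it is more than 24 hours old. If you want a strict sliding 24-hour window instead, something like now.AddHours(-24) as the cutoff would be one way to do it.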


In addition, we need to add the WCF message inspector to our service. I use a ServiceHostFactory and a class deriving from ServiceHost, and override the OnOpening method:

        protected override void OnOpening()
        {
            ..


            if (Description.Behaviors.Find<RecordingMessageBehavior>() == null)
                Description.Behaviors.Add(new RecordingMessageBehavior()); 


            base.OnOpening();
        }

Here we register the service behavior for each WCF service so that it intercepts the network traffic, clones the messages as required and logs the contents to a database table. This gives an overview of the requests and responses and makes security audits, inspections and logging possible. You will need to adjust the code above to make it work against your codebase, but I hope this article is of general interest. Being able to record WCF network traffic between services and clients is a very powerful feature that gives developer teams more control of their running environments, and it is also an important security boost. However, note that huge amounts of data will accumulate in many production environments. Therefore, this implementation checks at most once an hour and deletes traffic log rows that are more than about a day old. The following SQL query is a very convenient way to display the data in SQL Server Management Studio, where an XML column can be opened and read as a document.



select top 100 convert(xml, SoapMessage) as Payload, * from trafficlog

Obviously, adding a criterion on the Recorded timestamp column will pinpoint the returned data from this table more effectively. Retrieving and analysing the data can be problematic if the amount of data reaches several gigabytes. Here is a PowerShell script that retrieves the data in chunks. The data is retrieved in 1-minute segments, which in one specific case split the data into XML files of about 40 MB each. This makes it possible to search the data using a text editor.


PowerShell script for retrieving traffic log data contents in a chunked fashion


#Traffic Log Bulk Copy Util

$databaseName='MYDATABASE_INSTANCENAME'
$databaseServer='MYDATABASE_SERVER'
$dateStamp = '130415'
$starthours = 12
$minuteSegmentSize = 1
$minuteSegmentNumbers = 160
$dateToInvestigate = '13.04.2015 00:00:00'
$outputFolder='SOMESHARE_UNC_PATH'

Write-Host "Traffic Log Request Bulk Output Logger"


# $dateToInvestigate is loop-invariant, so parse it once (HH = 24-hour clock)
$baseTime = [datetime]::ParseExact($dateToInvestigate, "dd.MM.yyyy HH:mm:ss", $null).AddHours($starthours)

for($i=0; $i -lt $minuteSegmentNumbers; $i++){
 # Start of segment number $i, counted in whole minutes from the base time
 $dt = $baseTime.AddMinutes($i * $minuteSegmentSize)
 $dtEnd = $dt.AddMinutes($minuteSegmentSize)
 $startWindow = $dt.ToString("yyyy-MM-dd HH:mm:ss")
 $endWindow = $dtEnd.ToString("yyyy-MM-dd HH:mm:ss")
 # ${dateStamp} needs braces, otherwise PowerShell looks for a variable named $dateStamp_bulkcopy_
 $destinationFile = Join-Path -Path $outputFolder -ChildPath "trafficLog_${dateStamp}_bulkcopy_$i.xml"
 Write-Host $("StartWindow: " + $startWindow + " EndWindow: " + $endWindow + " Destination file: " + $destinationFile)

 # Half-open window (>= start, < end) so that no row ends up in two files
 $bcpCmd = "bcp ""SELECT SoapMessage FROM TrafficLog WHERE IsRequest=1 AND Recorded >= '$startWindow' and Recorded < '$endWindow'"" queryout ""$destinationFile"" -S $databaseServer -d $databaseName -c -t"","" -r""\n"" -T -x"

 Write-Host "Bulk copy cmd: "
 Write-Host $bcpCmd
 Write-Host "GO!"

 Invoke-Expression $bcpCmd

 Write-Host "DONE!"
}
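
If you would rather drive the chunked retrieval from C#, for instance by calling GetTrafficLog once per chunk, the window calculation can be sketched like this. The names TimeWindow and TrafficLogWindows are hypothetical, and the sketch assumes half-open windows (start inclusive, end exclusive) so that no row falls into two chunks.

```csharp
using System;
using System.Collections.Generic;

public struct TimeWindow
{
    public DateTime Start; // inclusive
    public DateTime End;   // exclusive
}

public static class TrafficLogWindows
{
    // Splits a time range into consecutive, non-overlapping segments of equal length.
    public static List<TimeWindow> Create(DateTime firstStart, int segmentMinutes, int segmentCount)
    {
        if (segmentMinutes <= 0) throw new ArgumentOutOfRangeException("segmentMinutes");
        if (segmentCount < 0) throw new ArgumentOutOfRangeException("segmentCount");

        var windows = new List<TimeWindow>(segmentCount);
        for (int i = 0; i < segmentCount; i++)
        {
            DateTime start = firstStart.AddMinutes(i * segmentMinutes);
            windows.Add(new TimeWindow { Start = start, End = start.AddMinutes(segmentMinutes) });
        }
        return windows;
    }
}
```

With a start of 13.04.2015 12:00, a segment size of 1 minute and 160 segments, this covers 12:00 up to 14:40, and each window ends exactly where the next one begins.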

By following this article, you should now be ready to write your own WCF message inspector and record network traffic to your database, for logging, security audits and a better overview of the raw data in the form of SOAP XML. Please read my previous article if you want to gain insight into how to serialize data contracts and send them as binary data (byte arrays), even with GZip compression! Happy coding!

Sunday 24 May 2015

Programmatically compress data contracts saving bandwidth between clients and WCF

This article explains how to achieve compression when transferring data contracts between a WCF service and a client. Many developers using WCF have used Fiddler or a similar tool to inspect the network traffic. By default, data contracts are transmitted between services and clients serialized as SOAP XML, which usually means that the data is sent as uncompressed XML. This is not an issue for smaller data contracts, but when you start transmitting a lot of data, data contracts can soon grow into megabytes (MB) and the TTF (Time-To-Transfer) becomes noticeable, especially on low-bandwidth devices such as smart phones! It is possible to configure compression in IIS, but more control can be achieved with a mini-framework I have developed, which I present next. First we need some code to compress data. The following class, GzipByteArrayCompressionUtility, uses the MemoryStream and BufferedStream classes in the System.IO namespace and the GZipStream class in the System.IO.Compression namespace:

GzipByteArrayCompressionUtility



using System;
using System.IO;
using System.IO.Compression;

namespace SomeAcme.SomeProduct.Common.Compression
{

    public static class GzipByteArrayCompressionUtility
    {

        private static readonly int bufferSize = 64 * 1024; //64kB

        public static byte[] Compress(byte[] inputData)
        {
            if (inputData == null)
                throw new ArgumentNullException("inputData", "inputData must be non-null");

            using (var compressIntoMs = new MemoryStream())
            {
                using (var gzs = new BufferedStream(new GZipStream(compressIntoMs,
                 CompressionMode.Compress), bufferSize))
                {
                    gzs.Write(inputData, 0, inputData.Length);
                }
                return compressIntoMs.ToArray();
            }
        }

        public static byte[] Decompress(byte[] inputData)
        {
            if (inputData == null)
                throw new ArgumentNullException("inputData", "inputData must be non-null");

            using (var compressedMs = new MemoryStream(inputData))
            {
                using (var decompressedMs = new MemoryStream())
                {
                    using (var gzs = new BufferedStream(new GZipStream(compressedMs,
                     CompressionMode.Decompress), bufferSize))
                    {
                        gzs.CopyTo(decompressedMs);
                    }
                    return decompressedMs.ToArray();
                }
            }
        }

        //private static void Pump(Stream input, Stream output)
        //{
        //    byte[] bytes = new byte[4096];
        //    int n;
        //    while ((n = input.Read(bytes, 0, bytes.Length)) != 0)
        //    {
        //        output.Write(bytes, 0, n); 
        //    }
        //}

    }


}
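
A quick way to convince yourself that the round trip works, and that repetitive XML compresses well, is a small stand-alone sketch like the following. It uses GZipStream directly without the BufferedStream wrapper, and the class name GzipRoundTripSketch and the sample XML are made up for the demo, so treat it as an illustration rather than a replacement for the utility class above.

```csharp
using System;
using System.IO;
using System.IO.Compression;
using System.Text;

public static class GzipRoundTripSketch
{
    public static byte[] Compress(byte[] input)
    {
        using (var output = new MemoryStream())
        {
            using (var gzip = new GZipStream(output, CompressionMode.Compress))
            {
                gzip.Write(input, 0, input.Length);
            } // disposing the GZipStream writes the final gzip block into 'output'
            return output.ToArray(); // ToArray still works after the stream has been closed
        }
    }

    public static byte[] Decompress(byte[] input)
    {
        using (var source = new MemoryStream(input))
        using (var gzip = new GZipStream(source, CompressionMode.Decompress))
        using (var output = new MemoryStream())
        {
            gzip.CopyTo(output);
            return output.ToArray();
        }
    }

    // Builds a deliberately repetitive XML document, similar in shape to a list of data contracts.
    public static byte[] BuildRepetitiveXml(int itemCount)
    {
        var sb = new StringBuilder("<Items>");
        for (int i = 0; i < itemCount; i++)
            sb.Append("<Item><Id>").Append(i).Append("</Id><Name>Operation</Name></Item>");
        sb.Append("</Items>");
        return Encoding.UTF8.GetBytes(sb.ToString());
    }
}
```

Note that the compressed bytes must be read only after the GZipStream has been disposed, otherwise the gzip footer is missing and decompression fails.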



Further, it is necessary to have some utility methods that handle data contracts and perform both the compression and decompression. I have also added some handy methods in this class for cloning data contracts.

DataContractUtility



using System;
using System.Collections.Generic;
using System.Configuration;
using System.IO;
using System.Runtime.Serialization;
using System.Text;
using System.Xml;
using System.Xml.Serialization;


namespace SomeAcme.SomeProduct.Common
{

    /// <summary>
    /// Utility methods for serializing, deserializing and cloning data contracts
    /// </summary>
    /// <remarks>The serialization, deserialization and cloning here is 
    /// Limited to data contracts, as DataContractSerializer is being used for these operations</remarks>
    public static class DataContractUtility
    {

        public static byte[] ToByteArray<T>(T dataContract) where T : class
        {
            if (dataContract == null)
            {
                throw new ArgumentNullException("dataContract");
            }

            using (MemoryStream memoryStream = new MemoryStream())
            {
                var serializer = new DataContractSerializer(typeof(T));
                serializer.WriteObject(memoryStream, dataContract);
                memoryStream.Position = 0;
                return memoryStream.ToArray();                
            }
        }

        public static List<TResult> GetDataContractsFromDataContractContainer<TContainer, TResult>(TContainer container, CompressionSetupDataContract setup) 
            where TContainer : class, IDataContractContainer<TResult>  
            where TResult : class
        {
            if (container == null)
                return null;
            if (!container.IsBinary || !IsUseNetworkCompression(setup))
                return container.SerialPayload;
            else
            {
                //The service only gzips the payload when the container says so (see GetContainerInstanceAfterResult)
                byte[] rawData = container.IsGzipped
                    ? GzipByteArrayCompressionUtility.Decompress(container.BinaryPayload)
                    : container.BinaryPayload;
                var dataContracts = DataContractUtility.DeserializeFromByteArray<List<TResult>>(rawData);
                return dataContracts;
            }
        }

        private static bool IsUseNetworkCompression(CompressionSetupDataContract setup)
        {
            if (setup != null)
                return setup.UseNetworkCompression;

            if (ConfigurationManager.AppSettings[Constants.UseNetworkCompression] == null)
                return false; 

            bool isUseNetworkCompression = false;
            if (!bool.TryParse(ConfigurationManager.AppSettings[Constants.UseNetworkCompression], out isUseNetworkCompression))
                return false;
            else
                return isUseNetworkCompression;            
        }

        private static int GetNetworkCompressionItemTreshold(CompressionSetupDataContract setup)
        {
            if (setup != null)
                return setup.NetworkCompressionItemTreshold; 

            int standardNetworkCompressionItemTreshold = 100;
            if (ConfigurationManager.AppSettings[Constants.NetworkCompressionItemTreshold] == null)
                return standardNetworkCompressionItemTreshold;
            int networkCompressionItemTreshold = standardNetworkCompressionItemTreshold;
            if (!int.TryParse(ConfigurationManager.AppSettings[Constants.NetworkCompressionItemTreshold], out networkCompressionItemTreshold))
                return standardNetworkCompressionItemTreshold;
            else
                return networkCompressionItemTreshold;
        }



        public static TContainer GetContainerInstanceAfterResult<TContainer, TResult>(List<TResult> result, CompressionSetupDataContract setup) 
            where TContainer : class, IDataContractContainer<TResult>, new()
            where TResult : class
        {
            TContainer container = new TContainer();

            if (IsUseNetworkCompression(setup) && (result != null && result.Count >= Math.Min(container.BinaryTreshold, GetNetworkCompressionItemTreshold(setup))))
            {
                if (container.IsGzipped)
                    container.BinaryPayload = GzipByteArrayCompressionUtility.Compress(DataContractUtility.ToByteArray(result));
                else 
                    container.BinaryPayload = DataContractUtility.ToByteArray(result); 
                
                container.IsBinary = true;
            }
            else
            {
                container.SerialPayload = result;
            }

            return container; 
        }

        public static string SerializeObject<T>(T dataContract) where T : class
        {
            using (var memoryStream = new MemoryStream())
            {
                using (var streamReader = new StreamReader(memoryStream))
                {
                    var serializer = new DataContractSerializer(typeof(T));
                    serializer.WriteObject(memoryStream, dataContract);
                    memoryStream.Position = 0;
                    return streamReader.ReadToEnd();
                }
            }
        }

        public static T DeserializeObject<T>(string serializedContent) where T : class
        {
            return DeserializeObject<T>(serializedContent, Encoding.UTF8);
        }

        public static T DeserializeObject<T>(string serializedContent, Encoding encoding) where T : class
        {
            T result = null;
            using (Stream memoryStream = new MemoryStream())
            {
                var serializer = new DataContractSerializer(typeof(T));
                byte[] data = encoding.GetBytes(serializedContent);
                memoryStream.Write(data, 0, data.Length);
                memoryStream.Position = 0;
                result = (T)serializer.ReadObject(memoryStream);
            }
            return result;
        }

        public static T DeserializeFromByteArray<T>(byte[] data) where T : class
        {
            T result = null;
            using (Stream memoryStream = new MemoryStream())
            {
                var serializer = new DataContractSerializer(typeof(T));
                memoryStream.Write(data, 0, data.Length);
                memoryStream.Position = 0;
                result = (T)serializer.ReadObject(memoryStream);
            }
            return result;
        }

        public static T CloneObject<T>(T dataContract) where T : class
        {
            T result = null;
            var serializedContent = SerializeObject<T>(dataContract);
            result = DeserializeObject<T>(serializedContent);
            return result;
        }

    }

}
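
To see the two steps together, serialize with DataContractSerializer and then gzip the bytes, here is a compact stand-alone sketch. SampleItem is a hypothetical data contract that I made up for the demo, and PayloadSketch only mirrors the idea behind ToByteArray, DeserializeFromByteArray and the compression utility; it is not the classes above.

```csharp
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using System.Runtime.Serialization;

[DataContract]
public class SampleItem // hypothetical data contract, not part of the article's code base
{
    [DataMember(Order = 1)]
    public int Id { get; set; }

    [DataMember(Order = 2)]
    public string Name { get; set; }
}

public static class PayloadSketch
{
    // Service side: data contracts -> XML bytes -> gzipped bytes
    public static byte[] Pack(List<SampleItem> items)
    {
        byte[] xmlBytes;
        using (var xml = new MemoryStream())
        {
            new DataContractSerializer(typeof(List<SampleItem>)).WriteObject(xml, items);
            xmlBytes = xml.ToArray();
        }

        using (var packed = new MemoryStream())
        {
            using (var gzip = new GZipStream(packed, CompressionMode.Compress))
            {
                gzip.Write(xmlBytes, 0, xmlBytes.Length);
            }
            return packed.ToArray();
        }
    }

    // Client side: gzipped bytes -> XML bytes -> data contracts
    public static List<SampleItem> Unpack(byte[] payload)
    {
        using (var source = new MemoryStream(payload))
        using (var gzip = new GZipStream(source, CompressionMode.Decompress))
        using (var xml = new MemoryStream())
        {
            gzip.CopyTo(xml);
            xml.Position = 0; // rewind before deserializing
            return (List<SampleItem>)new DataContractSerializer(typeof(List<SampleItem>)).ReadObject(xml);
        }
    }
}
```

The same serializer type (here List of SampleItem) has to be used on both sides, which is why the utility class above deserializes into a List of TResult.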



Note that I use two app settings here to control the config of compression in web.config:

    <add key="UseNetworkCompression" value="true" />
    <add key="NetworkCompressionItemTreshold" value="100" />

The app setting UseNetworkCompression turns the compression on and off. If compression is turned off, we fall back to the default SOAP XML serialization. If it is turned on, the data is packed into a compressed byte array. The app setting NetworkCompressionItemTreshold is the minimum number of data contract items that will trigger compression. We also use two constants for the setting names:

        public const string UseNetworkCompression = "UseNetworkCompression";
        public const string NetworkCompressionItemTreshold = "NetworkCompressionItemTreshold";
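
How the two settings interact with the threshold on the container can be a little hard to read out of GetContainerInstanceAfterResult, so here is the decision isolated as a sketch. The class and parameter names are hypothetical; the logic is meant to mirror the condition in that method, where the lower of the two thresholds wins.

```csharp
using System;

public static class CompressionDecisionSketch
{
    // Returns true when the result should be sent as a binary (optionally gzipped) payload.
    public static bool ShouldUseBinaryPayload(
        bool useNetworkCompression, // the UseNetworkCompression setting
        int itemCount,              // number of data contract items in the result
        int containerThreshold,     // BinaryTreshold set by the container class
        int configuredThreshold)    // the NetworkCompressionItemTreshold setting
    {
        if (!useNetworkCompression)
            return false;

        // Math.Min means the lower of the two thresholds decides
        int effectiveThreshold = Math.Min(containerThreshold, configuredThreshold);
        return itemCount >= effectiveThreshold;
    }
}
```

So with a container threshold of 100 and a configured threshold of 50, a result with 60 items is already sent as a binary payload.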

That was the code that handles the compression and decompression of data contracts. Next I show sample code for how to use it. First we need to create a "container data class" for a demo of this mini-framework. Each class that is to be a "container data contract" implements the interface IDataContractContainer:

using System.Collections.Generic;



    public interface IDataContractContainer<TResult> where TResult : class
    {

        /// <summary>
        /// Set to true if the binary payload is to be used
        /// </summary>
        bool IsBinary { get; set; }

        /// <summary>
        /// If IsBinary is false, the serial payload contains the data (SOAP XML based data contracts)
        /// </summary>
        List<TResult> SerialPayload { get; set; }

        /// <summary>
        /// Byte array which is the binary payload. Usually the byte array is also Gzipped.
        /// </summary>
        byte[] BinaryPayload { get; set; }

        /// <summary>
        /// The limit when the use of binary payload should be honored.
        /// </summary>
        int BinaryTreshold { get; set; }

        /// <summary>
        /// If true, the binary payload is Gzipped (compressed)
        /// </summary>
        bool IsGzipped { get; set; }

    }



The following class implements the IDataContractContainer interface:

using System.Collections.Generic;
using System.Runtime.Serialization; 


   
    [DataContract(Namespace=Constants.DataContractNamespace20091001)]
    public class OperationItemsContainerDataContract : IDataContractContainer<OperationItemDataContract>
    {

        public OperationItemsContainerDataContract()
        {
            BinaryTreshold = 100;
            IsGzipped = true;
        }

        [DataMember(Order = 1)]
        public byte[] BinaryPayload { get; set; }

        [DataMember(Order = 2)]
        public List<OperationItemDataContract> SerialPayload { get; set; }

        [DataMember(Order = 3)]
        public bool IsBinary { get; set; }

        [DataMember(Order = 4)]
        public int BinaryTreshold { get; set; }

        [DataMember(Order = 5)]
        public bool IsGzipped { get; set; }

    }




Next up is some example code showing how to use this on the service side:


        public OperationItemsContainerDataContract GetOperationItems(OperationsRequestDataContract request)
        {

..
            var operations = SomeManager.GetSomeData(request);
..


            OperationItemsContainerDataContract result =
                DataContractUtility.GetContainerInstanceAfterResult<OperationItemsContainerDataContract, OperationItemDataContract>(operations, null);
            
            return result;
        }


The code above includes some production code, but what matters here is the call to DataContractUtility.GetContainerInstanceAfterResult. If compression is activated in web.config and the result set has reached the item threshold, this packs the data contract items into a byte array, and applies GZip compression if the container class is set up for it. See the constructor of the OperationItemsContainerDataContract class, where we set IsGzipped to true. On the client side, it is necessary to retrieve the data. On the server side we used the GetContainerInstanceAfterResult method; on the client side we now use the GetDataContractsFromDataContractContainer method.

class OperationItemProvider:

  private void LoadDailyOperationsByCurrentTheater(Collection operationItems)
        {
           
            try
            {
                var container = SomeAgent.GetSomeOperationItems(request);


                var operations = DataContractUtility.GetDataContractsFromDataContractContainer<OperationItemsContainerDataContract, OperationItemDataContract>(container, Context.CompressionSetup);

        
            }
            catch (SomeAcmeClientException ex)
            {
                DispatcherUtil.InvokeAction(() => EventAggregator.GetEvent<ErrorMessageEvent>().Publish(new ErrorMessageEventArg { ErrorMessage = ex.Message }));
            }

..
        
        }

Note that the code above includes some production code. The important part is the use of the method GetDataContractsFromDataContractContainer. The object Context.CompressionSetup holds the compression settings from the web.config on our server side, which the client retrieves with a simple service call (the server-side method uses the ConfigurationManager). Here is the WCF server method:


        public CompressionSetupDataContract GetCompressionSetup()
        {
            bool standardUseNetworkCompression = false;
            int standardNetworkCompressionItemTreshold = 100;       
           
            var setup = new CompressionSetupDataContract
            {
                UseNetworkCompression = standardUseNetworkCompression, 
                NetworkCompressionItemTreshold = standardNetworkCompressionItemTreshold
            };
            if (ConfigurationManager.AppSettings[Constants.UseNetworkCompression] != null)
            {
                bool useNetworkCompression = standardUseNetworkCompression;
                setup.UseNetworkCompression = bool.TryParse(ConfigurationManager.AppSettings[SomeAcme.SomeProduct.Common.Constants.UseNetworkCompression], out useNetworkCompression) ?
                    useNetworkCompression : false;
            }

            if (ConfigurationManager.AppSettings[Constants.NetworkCompressionItemTreshold] != null)
            {
                int networkCompressionItemTreshold = standardNetworkCompressionItemTreshold;
                setup.NetworkCompressionItemTreshold = 
                    int.TryParse(ConfigurationManager.AppSettings[SomeAcme.SomeProduct.Common.Constants.NetworkCompressionItemTreshold], out networkCompressionItemTreshold) ?
                    networkCompressionItemTreshold : standardNetworkCompressionItemTreshold;
            }          

            return setup; 
        }
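
The parse-with-fallback pattern used in the method above can be isolated in a small helper. The following is only a minimal sketch: the SettingParser class, the dictionary that stands in for ConfigurationManager.AppSettings, and the key names are assumptions made for illustration and are not part of the original code base.

```csharp
using System;
using System.Collections.Generic;

public static class SettingParser
{
    // Returns the parsed value, or the fallback when the raw text is missing or malformed.
    public static bool ParseBool(string raw, bool fallback)
    {
        bool parsed;
        return bool.TryParse(raw, out parsed) ? parsed : fallback;
    }

    // Same idea for integers; TryParse returns false for null input.
    public static int ParseInt(string raw, int fallback)
    {
        int parsed;
        return int.TryParse(raw, out parsed) ? parsed : fallback;
    }
}

public static class SettingParserDemo
{
    public static void Main()
    {
        // Hypothetical settings; a dictionary stands in for the appSettings section.
        var appSettings = new Dictionary<string, string>
        {
            { "UseNetworkCompression", "true" },
            { "NetworkCompressionItemTreshold", "not-a-number" }
        };

        bool useCompression = SettingParser.ParseBool(appSettings["UseNetworkCompression"], false);
        int threshold = SettingParser.ParseInt(appSettings["NetworkCompressionItemTreshold"], 100);

        // The malformed threshold falls back to the default of 100.
        Console.WriteLine("UseNetworkCompression={0}, threshold={1}", useCompression, threshold);
    }
}
```

With a helper like this, a missing key and a malformed value are handled the same way, since TryParse rejects null input as well.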

We also need to control the use of compression to be able to run integration tests on this:

        [ServiceLog]
        [AuthorizedRole(SomeAcmeRoles.Administrator)]
        public bool SetUseNetworkCompression(CompressionSetupDataContract compressionSetup)
        {
            if (compressionSetup == null)
                throw new ArgumentNullException(GetName.Of(() => compressionSetup));
            try
            {
                //ADDITIONAL SECURITY CHECKS OMITTED FROM PUBLIC DISPLAY.

                System.Configuration.Configuration webConfigApp = WebConfigurationManager.OpenWebConfiguration("~");

                webConfigApp.AppSettings.Settings[Common.Constants.UseNetworkCompression].Value = compressionSetup.UseNetworkCompression
                    ? "true"
                    : "false";

                webConfigApp.AppSettings.Settings[Common.Constants.NetworkCompressionItemTreshold].Value = compressionSetup.NetworkCompressionItemTreshold.ToString(); 

                webConfigApp.Save();
                return true;
            }
            catch (Exception err)
            {
                InterfaceBinding.GetInstance().WriteError(err.Message);
                return false;
            }
        }

Okay, so now we have the necessary code to test out the compression. Of course, to apply this mini-framework to your codebase, you will need to add at least the two app settings above and the two utility classes presented, plus the IDataContractContainer interface and an implementation of it, and use the two DataContractUtility methods noted on the service side and the client side. To note again:

On the client side, it is necessary to retrieve the data. On the server side we have used the GetContainerInstanceAfterResult method; on the client side we use the GetDataContractsFromDataContractContainer method.



Here is an integration test for testing out our mini framework:

        [Test]
        [Category(TestCategories.IntegrationTest)]
        public void GetOperationItemsWithCompressionEnabledDoesNotReturnEmpty()
        {
            var compressionSetup = new CompressionSetupDataContract
            {
                UseNetworkCompression = true,
                NetworkCompressionItemTreshold = 1
            };

            SystemServiceAgent.SetUseNetworkCompression(compressionSetup);

            //Arrange 
            var operationsRequest = new OperationsRequestDataContract()
            {
                .. // some init
            };

            //Act 
            OperationItemsContainerDataContract operationsRetrieved = ConcreteServiceAgent.GetSomeOperationItems(operationsRequest);

            //Assert 
            Assert.IsNotNull(operationsRetrieved);

            var operationItems =
                DataContractUtility.GetDataContractsFromDataContractContainer<OperationItemsContainerDataContract, OperationItemDataContract>(
                    operationsRetrieved, compressionSetup);

            Assert.IsNotNull(operationItems);
            CollectionAssert.IsNotEmpty(operationItems);

            compressionSetup.UseNetworkCompression = false;
            compressionSetup.NetworkCompressionItemTreshold = 100;
            SystemServiceAgent.SetUseNetworkCompression(compressionSetup);
        }


Note that compression means more CPU is used on the service side. In many cases you save bandwidth at the cost of CPU resources on an already busy application server running the WCF services. At the same time, more and more clients of WCF services are low-bandwidth devices, such as smartphones. I have run some tests retrieving about 1000 items of relatively large data contracts and saw the transferred data shrink from 5-6 megabytes (MB) to about 300 kB, a reduction of roughly 20x! WCF serialization using SOAP XML often results in gigantic amounts of data being transferred. On a Gigabit Ethernet network this might not be noticeable on a developer computer, but in a system with many simultaneous users, network bandwidth usage becomes really important, even for high-bandwidth clients. Finally, I must note that the code presented here is very generic. It is not limited to sending data from WCF services to clients; the only requirement is that the data you send consists of data contract items (classes and properties marked with the DataContract and DataMember attributes).
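
The pack-and-unpack round trip described in this post can be sketched as follows. This is not the DataContractUtility implementation, which is not shown in full here; SampleItem and GZipDataContractSketch are hypothetical names, and the sketch only assumes the standard DataContractSerializer and GZipStream classes.

```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using System.Runtime.Serialization;

[DataContract]
public class SampleItem
{
    [DataMember] public int Id { get; set; }
    [DataMember] public string Description { get; set; }
}

public static class GZipDataContractSketch
{
    // Serializes the items to XML bytes using the data contract serializer.
    public static byte[] Serialize<T>(List<T> items)
    {
        var serializer = new DataContractSerializer(typeof(List<T>));
        using (var buffer = new MemoryStream())
        {
            serializer.WriteObject(buffer, items);
            return buffer.ToArray();
        }
    }

    // Compresses a byte array with GZip.
    public static byte[] Compress(byte[] raw)
    {
        using (var output = new MemoryStream())
        {
            // The GZip stream must be closed before the result is read.
            using (var gzip = new GZipStream(output, CompressionMode.Compress, true))
            {
                gzip.Write(raw, 0, raw.Length);
            }
            return output.ToArray();
        }
    }

    // Reverses Compress.
    public static byte[] Decompress(byte[] packed)
    {
        using (var input = new MemoryStream(packed))
        using (var gzip = new GZipStream(input, CompressionMode.Decompress))
        using (var output = new MemoryStream())
        {
            gzip.CopyTo(output);
            return output.ToArray();
        }
    }

    // Reverses Serialize.
    public static List<T> Deserialize<T>(byte[] raw)
    {
        var serializer = new DataContractSerializer(typeof(List<T>));
        using (var buffer = new MemoryStream(raw))
        {
            return (List<T>)serializer.ReadObject(buffer);
        }
    }

    public static void Main()
    {
        var items = new List<SampleItem>();
        for (int i = 0; i < 1000; i++)
        {
            items.Add(new SampleItem { Id = i, Description = "Operation item " + i });
        }

        byte[] raw = Serialize(items);
        byte[] packed = Compress(raw);
        List<SampleItem> restored = Deserialize<SampleItem>(Decompress(packed));

        Console.WriteLine("Raw: {0} bytes, compressed: {1} bytes, items restored: {2}",
            raw.Length, packed.Length, restored.Count);
    }
}
```

The compression ratio depends heavily on how repetitive the serialized XML is, so the numbers printed by this sketch say nothing about the savings in a real system.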

Saturday 23 May 2015

Logging WPF binding errors to the Event Log and testing this

WPF developers know that WPF binding errors are elusive. The XAML code is populated with binding expressions that are not strongly typed, but written as text. Whereas Razor supports strongly typed MVC views, WPF developers must keep a keen eye on the debugger output window while their application runs to spot WPF binding errors. These binding errors are not thrown as exceptions and are hard to track down. In addition, if there is one WPF binding error per item in an items control, the WPF application will quickly become slow, which often indicates one or more binding errors. It is possible to log binding errors to the Event Log. In this article I will present the code I use to trace binding errors to the Event Log, and an NUnit integration test that exercises the binding trace listener. Let's first review the code for the WPF binding error trace listener:


 public class BindingErrorTraceListener : DefaultTraceListener
    {
        private static BindingErrorTraceListener _Listener;

        public static void SetTrace()
        { SetTrace(SourceLevels.Error, TraceOptions.None); }

        public static void SetTrace(SourceLevels level, TraceOptions options)
        {
            if (_Listener == null)
            {
                _Listener = new BindingErrorTraceListener();
                PresentationTraceSources.DataBindingSource.Listeners.Add(_Listener);
            }

            _Listener.TraceOutputOptions = options;
            PresentationTraceSources.DataBindingSource.Switch.Level = level;
        }

        public static void CloseTrace()
        {
            if (_Listener == null)
            { return; }

            _Listener.Flush();
            _Listener.Close();
            PresentationTraceSources.DataBindingSource.Listeners.Remove(_Listener);
            _Listener = null;
        }



        private StringBuilder _Message = new StringBuilder();

        private BindingErrorTraceListener()
        { }

        public override void Write(string message)
        { _Message.Append(message); }

        public override void WriteLine(string message)
        {
            _Message.Append(message);

            var final = "WPF DataBinding Error detected: " +
                Environment.NewLine + _Message.ToString();
            _Message.Length = 0;

            var logger = new EventLogProvider();
            if (logger != null)
                logger.Log("OpPlan 4.x Client WPF Data Binding Error found: " + final, Microsoft.Practices.Prism.Logging.Category.Warn,
                     Microsoft.Practices.Prism.Logging.Priority.Medium);
        }
    }

The class above, BindingErrorTraceListener, inherits from DefaultTraceListener in the System.Diagnostics namespace. It adds itself to the PresentationTraceSources.DataBindingSource.Listeners collection, which is also in the System.Diagnostics namespace. The class uses another class called EventLogProvider, which implements ILoggerFacade (used in PRISM applications) and will be presented below. Please note that a DefaultTraceListener is already present in this collection. The following code adds the implementation above to the listener collection in the OnStartup method of the WPF application class code-behind, App.xaml.cs:

public partial class App {

.. 


 protected override void OnStartup(StartupEventArgs e) {
  base.OnStartup(e);

  if (Debugger.IsAttached){
   BindingErrorTraceListener.SetTrace();   
  }

 }

}

Note that the listener is added to the listener collection if a debugger is attached. WPF binding errors should be detected by the developer and resolved before software is published and/or shipped to the customers. The EventLogProvider class is simply a wrapper to the EventLog class in System.Diagnostics namespace.

using System.Configuration;
using System.Diagnostics;
using Hemit.Diagnostics;
using Microsoft.Practices.Prism.Logging;

namespace Hemit.OpPlan.Client.Infrastructure
{
    //[Export(typeof(ILoggerFacade))]
    //[PartCreationPolicy(CreationPolicy.Shared)]
    public class EventLogProvider : ILoggerFacade
    {
        public const int MaxLogMessageLength = 32765;
        public string EventLogSourceName { get; set; }


        public EventLogProvider()
        {
            EventLogSourceName = ConfigurationManager.AppSettings[Constants.EventLogSourceNameMefKey];
        }

        private void WriteEntry(string message, Category category, Priority priority)
        {
            int eventID = 0;

            System.Diagnostics.EventLog.WriteEntry(EventLogSourceName, message, GetEventLogEntryType(category), eventID, GetPriorityId(priority));
        }

        private static EventLogEntryType GetEventLogEntryType(Category category)
        {
            switch (category)
            {
                case Category.Debug:
                    return EventLogEntryType.Information;
                case Category.Exception:
                    return EventLogEntryType.Error;
                case Category.Info:
                    return EventLogEntryType.Information;
                case Category.Warn:
                    return EventLogEntryType.Warning;
                default:
                    return EventLogEntryType.Error;
            }
        }

        private static short GetPriorityId(Priority priority)
        {
            switch (priority)
            {
                case Priority.None:
                    return 0;
                case Priority.High:
                    return 1;
                case Priority.Medium:
                    return 2;
                case Priority.Low:
                    return 3;
                default:
                    return 0;
            }
        }

        #region ILoggerFacade Members

        public void Log(string message, Category category, Priority priority)
        {
            WriteEntry(message, category, priority);
        }

        #endregion
    }
}

It is also useful to write an integration test to check that the WPF binding error trace listener actually works. I only use the integration test to simulate a WPF binding error, but one can extend the EventLogProvider class to retrieve an Event Log item that matches a certain text:


    [TestFixture]
    public class BindingErrorTraceListenerUtilityTest
    {

        [Test]
        [RequiresSTA]
        public void Wpf_BindingError_Causes_Logging_Through_BindingTraceErrorListener()
        {
             
            BindingErrorTraceListener.SetTrace();
           
            var sausages = new[]
            {
                new {Product = "Hot dog", Cost = 25},
                new {Product = "Wiener", Cost = 42},
                new {Product = "Bratwurst", Cost = 53},
                new {Product = "Trailer grill shrimp salad with all extras deluxe edition", Cost = 112 }
            };

            string tbName = "tbGoof";
            var sausageBooth = new ListBox();
            var vt = new FrameworkElementFactory(typeof(TextBlock));
            vt.Name = tbName;
            vt.SetBinding(TextBlock.TextProperty, new Binding("Costt")); //Here a WPF binding error is introduced on purpose

            var groupBox = new GroupBox {DataContext = sausages[3]};
            groupBox.SetBinding(HeaderedContentControl.HeaderProperty, "Costt");

            var bindingExpressionGroupHeader = BindingOperations.GetBindingExpressionBase(groupBox, HeaderedContentControl.HeaderProperty);
            if (bindingExpressionGroupHeader != null)
                bindingExpressionGroupHeader.UpdateTarget();

            sausageBooth.ItemTemplate = new DataTemplate()
            {
                VisualTree = vt
            };
            sausageBooth.ItemsSource = sausages;

            sausageBooth.UpdateLayout();

            IItemContainerGenerator generator = sausageBooth.ItemContainerGenerator;
            GeneratorPosition position = generator.GeneratorPositionFromIndex(0);
            using (generator.StartAt(position, GeneratorDirection.Forward, true))
            {
                foreach (object o in sausageBooth.Items)
                {
                    DependencyObject dp = generator.GenerateNext();
                    generator.PrepareItemContainer(dp);
                }
            }

            ListBoxItem firstComboBoxItem = (ListBoxItem) sausageBooth.ItemContainerGenerator.ContainerFromIndex(0);
            Assert.IsNotNull(firstComboBoxItem);
            firstComboBoxItem.ContentTemplate.Seal();
            var dt = firstComboBoxItem.ContentTemplate.LoadContent();
            Assert.IsNotNull(dt);
            var textblock = dt as TextBlock;
            Assert.IsNotNull(textblock);
            textblock.SetBinding(TextBlock.TextProperty, new Binding("Costt"));
            var bindingTextBlock = BindingOperations.GetBindingExpressionBase(textblock, TextBlock.TextProperty);
            Assert.IsNotNull(bindingTextBlock);
            bindingTextBlock.UpdateTarget();

        }
    }
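
The test above leaves the verification to a manual look at the Event Log. If an automated assertion is wanted, one option is a listener that keeps the traced messages in memory. The following is a minimal sketch under assumptions: CollectingTraceListener and the demo class are hypothetical names, and the listener is invoked directly through TraceListener.TraceEvent so that the example runs without WPF. In a WPF test it would instead be added to PresentationTraceSources.DataBindingSource.Listeners.

```csharp
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;

public class CollectingTraceListener : TraceListener
{
    private readonly StringBuilder _buffer = new StringBuilder();
    private readonly List<string> _messages = new List<string>();

    // Completed messages, one entry per WriteLine call.
    public IList<string> Messages { get { return _messages; } }

    // Partial output (for example the trace header) is buffered.
    public override void Write(string message)
    {
        _buffer.Append(message);
    }

    // WriteLine completes a message and stores it.
    public override void WriteLine(string message)
    {
        _buffer.Append(message);
        _messages.Add(_buffer.ToString());
        _buffer.Length = 0;
    }
}

public static class CollectingTraceListenerDemo
{
    // Simulates one traced error and returns what the listener captured.
    public static IList<string> TraceOneError(string text)
    {
        var listener = new CollectingTraceListener();
        // The source name and event id are arbitrary values chosen for this sketch.
        listener.TraceEvent(new TraceEventCache(), "DemoBindingSource", TraceEventType.Error, 40, text);
        return listener.Messages;
    }

    public static void Main()
    {
        var messages = TraceOneError("BindingExpression path error: 'Costt' property not found");
        foreach (var message in messages)
        {
            Console.WriteLine(message);
        }
    }
}
```

A test could then assert that Messages contains an entry mentioning the misspelled property name, instead of relying on a manual inspection.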


Note that the integration test above runs through without asserting at the end that the Event Log item was actually added, but it is easy to inspect the Event Log using the eventvwr command and check that the WPF binding error was logged. The BindingErrorTraceListener class adds WPF binding errors as warnings in the Event Log, as the following screenshot shows:

Wednesday 8 April 2015

Configuring CORS in MVC Web Api 2

This article describes how developers can handle the restrictions of the Same-Origin Policy that exists in all browsers today by relaxing these restrictions using CORS, Cross-Origin Resource Sharing. Developers who have worked with JavaScript and Ajax calls, for example using jQuery, know about the limitations that browsers impose on scripts through the same-origin policy. If a script written in JavaScript tries to access a remote resource on another host, or even another website and/or port on the same machine, chances are high that the Ajax call, although executed, will not have its response returned to the calling script. Typically, if a developer checks the Developer Tools in a browser (usually F12), the following error shows up:

XMLHttpRequest cannot load http://localhost:58289/api/products/. No 'Access-Control-Allow-Origin' header is present on the requested resource. Origin 'http://localhost:62741' is therefore not allowed access.

This error occurs because the remote resource, in this case a remote service which is another ASP.NET Web API 2 site on the same machine (localhost), does not allow this origin: the 'Access-Control-Allow-Origin' header is missing from the response. The origin itself is the requesting address and is sent in the request as the 'Origin' header. The request in this particular case was:

Accept:*/*
Accept-Encoding:gzip, deflate, sdch
Accept-Language:en-US,en;q=0.8,nb;q=0.6,sv;q=0.4
Cache-Control:max-age=0
Connection:keep-alive
Host:localhost:58289
Origin:http://localhost:62741
Referer:http://localhost:62741/index.html
User-Agent:Mozilla/5.0 (Windows NT 6.3; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2272.118 Safari/537.36

Note the Origin request header above. I have enabled CORS on the remote site on the same machine but configured it to deny this address. First, to enable CORS in an MVC Web API 2 website or web application, install the following NuGet package: Microsoft ASP.NET Web API 2.2 Cross-Origin Support 5.2.3

Install-Package Microsoft.AspNet.WebApi.Cors

Next up, edit the class WebApiConfig in the App_Start folder:

   public static void Register(HttpConfiguration config)
   {

            config.EnableCors(); 
            // Web API configuration and services

            // Web API routes
            config.MapHttpAttributeRoutes();

            config.Routes.MapHttpRoute(
                name: "DefaultApi",
                routeTemplate: "api/{controller}/{id}",
                defaults: new { id = RouteParameter.Optional }
            );
  }

The line that was added here in addition to the default setup is the line:
config.EnableCors(); 
It is possible to allow all calls to this remote site by setting up the web.config file with the following in <system.webServer><httpProtocol>:

    <customHeaders>
      <add name="Access-Control-Allow-Origin" value="*" />
    </customHeaders>

However, this is not always desirable, as one often knows which origins will call this remote resource (site), and it is possible to set this up programmatically. So for now, we comment out this line in web.config, and in the following example we use the [EnableCors] attribute:

using ProductsApp.Models;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.Web.Http;
using System.Web.Http.Cors;
using System.Web.Http.Description;

namespace ProductsApp.Controllers
{

    [EnableCors(origins: "http://localhost:62742", headers: "*", methods: "*")]
    public class ProductsController : ApiController
    {

        private readonly Product[] products = new Product[]{
            new Product { Id = 1, Name = "Tomato Soup", Category = "Groceries", Price = 1 },
            new Product { Id = 2, Name = "Yo-yo", Category = "Toys", Price = 3.75M },
            new Product { Id = 3, Name = "Hammer", Category = "Hardware", Price = 16.99M }
        };

        public IEnumerable<Product> GetAllProducts()
        {
            return products; 
        }

      
        public async Task<IHttpActionResult> GetProduct(int id)
        {
            Product product = await GetProductAsync(id);
            if (product == null)
            {
                return NotFound(); 
            }
            return Ok(product); 
        }

        private async Task<Product> GetProductAsync(int id)
        {
            await Task.Delay(1000); 
            var product = products.FirstOrDefault(p => p.Id == id);
            return product;
        }      

    }

}

The following CORS setup was applied to the entire Web API v2 controller, which inherits from ApiController:
[EnableCors(origins: "http://localhost:62742", headers: "*", methods: "*")]
Here, the origins parameter lists the allowed origins (comma-separated). The wildcard "*" can be used here to allow all origins. The headers and methods parameters configure the CORS setup further. The [EnableCors] attribute can be applied to a controller or to a method inside the Web API v2 controller for fine-grained control. It is also possible to define a custom CORS policy by using, instead of the [EnableCors] attribute, a custom class that implements the interface ICorsPolicyProvider:
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method,
                AllowMultiple = false)]
public class EnableCorsForPaidCustomersAttribute :
  Attribute, ICorsPolicyProvider
{
  public async Task<CorsPolicy> GetCorsPolicyAsync(
    HttpRequestMessage request, CancellationToken cancellationToken)
  {
    var corsRequestContext = request.GetCorsRequestContext();
    var originRequested = corsRequestContext.Origin;
    if (await IsOriginFromAPaidCustomer(originRequested))
    {
      // Grant CORS request
      var policy = new CorsPolicy
      {
        AllowAnyHeader = true,
        AllowAnyMethod = true,
      };
      policy.Origins.Add(originRequested);
      return policy;
    }
    else
    {
      // Reject CORS request
      return null;
    }
  }
  private async Task<bool> IsOriginFromAPaidCustomer(
    string originRequested)
  {
    // Do database look up here to determine if origin should be allowed
    return true;
  }
}
Note here that we can do a database lookup inside the method IsOriginFromAPaidCustomer above. The class inherits from Attribute, so it is applied in place of the [EnableCors] attribute. We can add properties for further customization. A database lookup lets us control which callers are allowed to make CORS calls from their (Java)scripts without having to recompile, whereas the [EnableCors] attribute hardcodes the allowed origins. There are further details to CORS, but this is really what is required to make Ajax calls like the one below succeed and avoid the CORS violation.

    <script>

        $(document).ready(function () {

            var uri = 'http://localhost:58289/api/products/'; 

            $.ajax({
                url: uri,
                success: function (data) {
                    $.each(data, function (key, item) {
                        $('<li>', { text: formatItem(item) }).appendTo($('#products'));
                    });
                }
            });

            function formatItem(item) {
                return item.Name + ': $' + item.Price + ' ID: ' + item.Id;
            }

        });

    </script>

If CORS (Cross-Origin Resource Sharing) should not be used for a given method or controller, use the [DisableCors] attribute. It is also possible to set up [EnableCors] with methods = "", i.e. no methods (HTTP verbs) should be allowed CORS.
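
Conceptually, the decision behind an origins list boils down to comparing the Origin request header with the configured origins. The following is a simplified sketch of that idea, not the actual implementation inside System.Web.Http.Cors; the CorsOriginCheck class and its method are hypothetical names used for illustration only.

```csharp
using System;
using System.Linq;

public static class CorsOriginCheck
{
    // Returns the value to send in the Access-Control-Allow-Origin header,
    // or null when the origin is rejected and the header should be omitted.
    public static string ResolveAllowOrigin(string requestOrigin, string allowedOrigins)
    {
        if (string.IsNullOrEmpty(requestOrigin))
        {
            return null; // no Origin header, so not a cross-origin request
        }

        var allowed = allowedOrigins
            .Split(',')
            .Select(origin => origin.Trim())
            .Where(origin => origin.Length > 0)
            .ToList();

        if (allowed.Contains("*"))
        {
            return "*"; // every origin is allowed
        }

        // Exact match on scheme, host and port.
        return allowed.Contains(requestOrigin) ? requestOrigin : null;
    }

    public static void Main()
    {
        const string configured = "http://localhost:62742";

        // The calling page in this article runs on port 62741 and is therefore rejected.
        Console.WriteLine(ResolveAllowOrigin("http://localhost:62741", configured) ?? "(rejected)");
        Console.WriteLine(ResolveAllowOrigin("http://localhost:62742", configured) ?? "(rejected)");
    }
}
```

This also illustrates why the request shown earlier fails: the port is part of the origin, so http://localhost:62741 does not match a policy that only lists http://localhost:62742.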

Saturday 14 March 2015

Resharper inspecting consistent locking of objects

Resharper offers great inspections of many different aspects of your code, such as consistent locking of objects. An example is shown below; I have been using Resharper 9.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace InconsistentLocking
{
    class Program
    {

        private static readonly object Locker = new object();


        private static readonly List<object> fooness = new List<object>(); 

        static void Main(string[] args)
        {

            DoSomeFoo(); 

        }

        private static void DoSomeFoo()
        {
            lock (Locker)
            {
                fooness.Add("Foo");
            }

            lock (Locker)
            {
                fooness.Remove("Foo");
            }

            lock (Locker)
            {
                fooness.Add("Baze");
            }

            //Resharper warning next - The field is sometimes used inside synchronized block and sometimes without synchronization 

            fooness.Remove("Baze");

            lock (Locker)
            {
                foreach (var foo in fooness)
                {
                    Debug.WriteLine(foo.ToString());
                }
            }


        }

    }
}

The code shows a simple example that uses the list called fooness. The warning on the line fooness.Remove("Baze"); indicates that the resource, in this case an in-memory list called fooness, is accessed there without taking the lock. To fix the warning, just hit the Resharper shortcut Alt+Enter and Resharper will surround the access with a lock block again.
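
One way to make such inconsistencies harder to introduce is to hide the list behind a small class in which every access takes the same lock. The following is a minimal sketch; SynchronizedBag and the demo class are hypothetical names and not part of the sample above.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

public class SynchronizedBag
{
    private readonly object _locker = new object();
    private readonly List<string> _items = new List<string>();

    // Every member takes the same lock before touching the list.
    public void Add(string item)
    {
        lock (_locker) { _items.Add(item); }
    }

    public bool Remove(string item)
    {
        lock (_locker) { return _items.Remove(item); }
    }

    public int Count
    {
        get { lock (_locker) { return _items.Count; } }
    }
}

public static class SynchronizedBagDemo
{
    // Adds items from several threads and returns the final count.
    public static int AddInParallel(int count)
    {
        var bag = new SynchronizedBag();
        Parallel.For(0, count, i => bag.Add("Foo" + i));
        return bag.Count;
    }

    public static void Main()
    {
        Console.WriteLine("Items added: {0}", AddInParallel(10000));
    }
}
```

Because the list is private, there is no way for calling code to reach it without going through a locked member.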

Tuesday 10 February 2015

Automatic mapping for deep object graphs in Entity Framework



Using the great code sample from the DevTrends article "Stop using AutoMapper in your source code", I have adjusted the code a bit to support not only object graphs up to two levels deep, but arbitrarily deep ones. The depth limit in the original sample is a design choice that can easily be extended to deeper levels. The code automatically generates the mapping code using Expression Trees. Example:


.New DevTrends.QueryableExtensionsExample.StudentSummary(){
    FirstName = $src.FirstName,
    LastName = $src.LastName,
    TutorName = ($src.Tutor).Name,
    TutorAddressStreet = (($src.Tutor).Address).Street,
    TutorAddressMailboxNumber = ((($src.Tutor).Address).Mailbox).Number,
    CoursesCount = ($src.Courses).Count
}


Consider the automatic mapping of an entity in Entity Framework to a flattened model, where the entity is a relatively deep object graph. The mapping code becomes tedious to write as the number of properties grows. If we stick to a convention where we use CamelCasing to denote levels in the object graph, we can address the subproperties arbitrarily deep with this code.

   public class StudentSummary
    {

        public string FirstName { get; set; }

        public string LastName { get; set; }

        public string NotADatabaseColumn { get; set; }

        public string TutorName { get; set; }

        public string TutorAddressStreet { get; set;  }

        public string TutorAddressMailboxNumber { get; set; }

        public int CoursesCount { get; set; }

        public string FullName
        {
            get { return string.Format("{0} {1}", FirstName, LastName); }
        }

    }

Note here that for a given property such as TutorAddressMailboxNumber, the code will look for the property Tutor.Address.Mailbox.Number, since this is what the CamelCasing convention used here implies. The implementation of the Project() method and the To() method on IQueryable<T> is shown next. The code uses Expression Trees to generate the required automatic mapping code. Note, when debugging, that the BuildBinding method runs lazily: the bindings are a deferred LINQ query, so the method is only called when the Expression.Lambda statement is executed.


using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text.RegularExpressions;

namespace DevTrends.QueryableExtensionsExample
{

    public static class QueryableExtensions
    {
        public static ProjectionExpression<TSource> Project<TSource>(this IQueryable<TSource> source)
        {
            return new ProjectionExpression<TSource>(source);
        }
    }

    public class ProjectionExpression<TSource>
    {
        private static readonly Dictionary<string, Expression> _expressionCache = new Dictionary<string, Expression>();

        private readonly IQueryable<TSource> _source;

        public ProjectionExpression(IQueryable<TSource> source)
        {
            _source = source;
        }

        public IQueryable<TDest> To<TDest>()
        {
          var queryExpression = GetCachedExpression<TDest>() ?? BuildExpression<TDest>();

            return _source.Select(queryExpression);
        }        

        private static Expression<Func<TSource, TDest>> GetCachedExpression<TDest>()
        {
            var key = GetCacheKey<TDest>();

            return _expressionCache.ContainsKey(key) ? _expressionCache[key] as Expression<Func<TSource, TDest>> : null;
        }

        private static Expression<Func<TSource, TDest>> BuildExpression<TDest>()
        {
            var sourceProperties = typeof(TSource).GetProperties();
            var destinationProperties = typeof(TDest).GetProperties().Where(dest => dest.CanWrite);
            var parameterExpression = Expression.Parameter(typeof(TSource), "src");
            
            var bindings = destinationProperties
                                .Select(destinationProperty => BuildBinding(parameterExpression, destinationProperty, sourceProperties))
                                .Where(binding => binding != null);

            var expression = Expression.Lambda<Func<TSource, TDest>>(Expression.MemberInit(Expression.New(typeof(TDest)), bindings), parameterExpression);

            var key = GetCacheKey<TDest>();

            _expressionCache.Add(key, expression);

            return expression;
        }        

        private static MemberAssignment BuildBinding(Expression parameterExpression, MemberInfo destinationProperty, IEnumerable<PropertyInfo> sourceProperties)
        {
            var sourceProperty = sourceProperties.FirstOrDefault(src => src.Name == destinationProperty.Name);

            if (sourceProperty != null)
            {
                return Expression.Bind(destinationProperty, Expression.Property(parameterExpression, sourceProperty));
            }

            var propertyNameComponents = SplitCamelCase(destinationProperty.Name);

            if (propertyNameComponents.Length >= 2)
            {
                sourceProperty = sourceProperties.FirstOrDefault(src => src.Name == propertyNameComponents[0]);
                if (sourceProperty == null)
                    return null; 

                var propertyPath = new List<PropertyInfo> { sourceProperty };
                TraversePropertyPath(propertyPath, propertyNameComponents, sourceProperty);

                if (propertyPath.Count != propertyNameComponents.Length)
                    return null; //must be able to identify the path 

                 MemberExpression compoundExpression = null;
                
                for (int i = 0; i < propertyPath.Count; i++)
                {
                    compoundExpression = i == 0 ? Expression.Property(parameterExpression, propertyPath[0]) : 
                        Expression.Property(compoundExpression, propertyPath[i]); 
                }

                return compoundExpression != null ? Expression.Bind(destinationProperty, compoundExpression) : null; 
            }

            return null;
        }

        private static List<PropertyInfo> TraversePropertyPath(List<PropertyInfo> propertyPath, string[] propertyNames, 
            PropertyInfo currentPropertyInfo, int currentDepth = 1)
        {
            if (currentDepth >= propertyNames.Count() || currentPropertyInfo == null)
                return propertyPath; //do not go deeper into the object graph

            PropertyInfo subPropertyInfo = currentPropertyInfo.PropertyType.GetProperties().FirstOrDefault(src => src.Name == propertyNames[currentDepth]); 
            if (subPropertyInfo == null)
                return null; //The property to look for was not found at a given depth 

            propertyPath.Add(subPropertyInfo);

            return TraversePropertyPath(propertyPath, propertyNames, subPropertyInfo, ++currentDepth);
        }

        private static string GetCacheKey<TDest>()
        {
            return string.Concat(typeof(TSource).FullName, typeof(TDest).FullName);
        }

        private static string[] SplitCamelCase(string input)
        {
            return Regex.Replace(input, "([A-Z])", " $1", RegexOptions.Compiled).Trim().Split(' ');
        }

    }    
}
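
The core idea of the traversal, turning a flattened name such as TutorAddressStreet into a chain of property accesses, can be shown in isolation. The following is a simplified sketch and not the code above: the classes and the BuildGetter method are hypothetical, the sketch skips the direct name match that BuildBinding tries first, and it throws instead of returning null when a path segment is not found.

```csharp
using System;
using System.Linq.Expressions;
using System.Text.RegularExpressions;

public class Address { public string Street { get; set; } }
public class Tutor { public string Name { get; set; } public Address Address { get; set; } }
public class Student { public string FirstName { get; set; } public Tutor Tutor { get; set; } }

public static class NestedPropertySketch
{
    // Builds src => src.A.B.C from a flattened name "ABC" split on upper-case letters.
    public static Func<TSource, TResult> BuildGetter<TSource, TResult>(string flattenedName)
    {
        string[] parts = Regex.Replace(flattenedName, "([A-Z])", " $1").Trim().Split(' ');

        ParameterExpression src = Expression.Parameter(typeof(TSource), "src");
        Expression body = src;
        foreach (string part in parts)
        {
            // Expression.Property throws an ArgumentException if the property does not exist.
            body = Expression.Property(body, part);
        }

        return Expression.Lambda<Func<TSource, TResult>>(body, src).Compile();
    }

    public static void Main()
    {
        var student = new Student
        {
            FirstName = "Ada",
            Tutor = new Tutor { Name = "Grace", Address = new Address { Street = "Elm Street 4" } }
        };

        Func<Student, string> getStreet = BuildGetter<Student, string>("TutorAddressStreet");
        Console.WriteLine(getStreet(student));
    }
}
```

In the real mapping code the expression is not compiled but handed to IQueryable.Select, so that Entity Framework can translate the property chain into SQL.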

The database is initialized with some sample data. The sample uses an Entity Framework Code First database.


using System;
using System.Data.Entity;
using System.Linq;

namespace DevTrends.QueryableExtensionsExample
{
    class Program
    {
        static void Main()
        {
            Database.SetInitializer(new TestDataContextInitializer());

            using (var context = new StudentContext())
            {
                ExampleOne(context);

                ExampleTwo(context);                
            }

            Console.Write("\nPress a key to exit: ");
            Console.ReadKey(); // returns on any key press, matching the prompt above
        }

        private static void ExampleOne(StudentContext context)
        {
            // This is the kind of projection code that we are replacing

            //var students = from s in context.Students
            //               select new StudentSummary
            //                          {
            //                              FirstName = s.FirstName,
            //                              LastName = s.LastName,
            //                              TutorName = s.Tutor.Name,
            //                              CoursesCount = s.Courses.Count
            //                          };


            // The line below performs exactly the same query as the code above.

            var students = context.Students.Project().To<StudentSummary>();

            Console.Write("Example One\n\n");

            // Uncomment this to see the generated SQL. The extension method produces the same SQL as the projection.
            //Console.WriteLine("SQL: \n\n{0}\n\n", students);

            foreach (var student in students)
            {
                Console.WriteLine("{0} is tutored by {1} in {2} subject(s).", student.FullName, student.TutorName, student.CoursesCount);
            }
        }

        private static void ExampleTwo(StudentContext context)
        {
            //var students = from s in context.Students
            //               select new AnotherStudentSummary()
            //               {
            //                   FirstName = s.FirstName,
            //                   LastName = s.LastName,
            //                   Courses = s.Courses
            //               };

            var students = context.Students.Project().To<AnotherStudentSummary>();

            Console.WriteLine("\nExample Two\n");
            
            //Console.WriteLine("SQL: \n\n{0}\n\n", students);

            foreach (var student in students)
            {
                var coursesString = string.Join(", ", student.Courses.Select(s => s.Description).ToArray());
                Console.WriteLine("{0} {1} takes the following courses: {2}", student.FirstName, student.LastName, coursesString);
            }
        }
    }        
}


I have tested the code with an object graph depth of four, which should be sufficient in most cases. The maximum depth is 10, simply to prevent the recursion from causing a possible stack overflow, but this limit can easily be adjusted in the code. With the code provided here, it becomes easier to automatically map Entity Framework entities to models. The code should be used for retrieving data, such as selects. Note that we are short-circuiting the change tracking of Entity Framework here, and that the property names of the flattened model must stick to a CamelCasing convention for the automatic mapping from the entity to work.
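
To make the CamelCasing convention concrete, here is a minimal sketch of how a flattened property name such as TutorName could be resolved to the nested path Tutor.Name with reflection. It is not the code from the download: the Student and Tutor classes, the FlatteningSketch class and the ResolvePath method are hypothetical names made up for this illustration, and the depth limit of 10 is only assumed to mirror the limit mentioned above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;

// Hypothetical entity types, used only for this illustration.
public class Tutor
{
    public string Name { get; set; }
}

public class Student
{
    public string FirstName { get; set; }
    public Tutor Tutor { get; set; }
}

public static class FlatteningSketch
{
    // Assumed cap on nesting depth, mirroring the limit of 10 described in the text.
    private const int MaxDepth = 10;

    // Splits "TutorName" into ["Tutor", "Name"] by putting a space before each capital letter.
    public static string[] SplitCamelCase(string input)
    {
        return Regex.Replace(input, "([A-Z])", " $1").Trim().Split(' ');
    }

    // Resolves a flattened name against the source type.
    // Returns the list of properties to traverse, or null if nothing matches.
    public static List<PropertyInfo> ResolvePath(Type sourceType, string flattenedName)
    {
        // A direct match (for example "FirstName") needs no flattening.
        PropertyInfo direct = sourceType.GetProperty(flattenedName);
        if (direct != null)
            return new List<PropertyInfo> { direct };

        string[] words = SplitCamelCase(flattenedName);
        if (words.Length > MaxDepth)
            return null; // do not go deeper than the assumed limit

        var path = new List<PropertyInfo>();
        Type current = sourceType;
        foreach (string word in words)
        {
            PropertyInfo next = current.GetProperty(word);
            if (next == null)
                return null; // the property was not found at this depth
            path.Add(next);
            current = next.PropertyType;
        }
        return path;
    }

    public static void Main()
    {
        List<PropertyInfo> path = ResolvePath(typeof(Student), "TutorName");
        Console.WriteLine(string.Join(".", path.Select(p => p.Name))); // prints "Tutor.Name"
    }
}
```

This sketch treats every capitalized word as exactly one nested property, so a nested property whose own name consists of several words would not be resolved by it. It only illustrates why the flattened model has to follow the naming convention.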

Download the sample code here

Source code in Visual Studio 2012 solution

Example integration test (the ObjectContextManager.ScopedOpPlanDataContext object used here is an example ObjectContext):


        [Test]
        [Category(TestCategories.IntegrationTest)]
        public void ProjectToPerformsExpected()
        {
            using (var context = ObjectContextManager.ScopedOpPlanDataContext)
            {
                var globalParameters = context.GlobalParameters.Project().To<GlobalParametersDataContract>().ToList();
                Assert.IsNotNull(globalParameters);
                CollectionAssert.IsNotEmpty(globalParameters);
            }
        }