Monday, October 24, 2011

LINQ to Dynamics AX - Data Access Layer (2)

As the title says, this post shows how to make our connection to Dynamics AX, via the Business Connector, generic.

First, let's create a .NET 4.0 class library project called LINQTest.

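We need a reference to Microsoft.Dynamics.BusinessConnectorNet, which is a .NET 2.0 component. To allow the .NET 4.0 runtime to load it, add an Application Configuration file to the project with the following content. Keep in mind that the runtime reads this setting from the host application's configuration file, so whichever application or service consumes this library needs the same startup element in its own config.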

<?xml version="1.0"?>
<configuration>
  <!-- Microsoft.Dynamics.BusinessConnectorNet is a .NET 2.0 component. The legacy v2 runtime
       activation policy below is what allows this .NET 4.0 project to load it. -->
  <startup useLegacyV2RuntimeActivationPolicy="true">
    <!-- Run on the .NET Framework 4.0 runtime. -->
    <supportedRuntime version="v4.0" sku=".NETFramework,Version=v4.0"/>
  </startup>
 </configuration>

Next, create a new class named DAXDataAccessLayer. Its purpose is to provide generic access to the Business Connector. It reads its settings (the appSettings keys AxBCUser, AxBCPassword, AxBCDomain and AxConfigFile) from the config file supplied by the consuming application or service, and it depends on one additional class, DAXUtils.

using System;
using System.Configuration;
using Microsoft.Dynamics.BusinessConnectorNet;
using System.Reflection;
using System.Data;
using System.Diagnostics;


namespace LINQTest
{
    public class DAXDataAccessLayer
    {
        private static Axapta staticAxapta;
        private static object lockAxObject = new object();

        static DAXDataAccessLayer()
        {
            staticAxapta = new Axapta();
            LOGON(staticAxapta);
        }

        public static string GetAxConfig()
        {
            try
            {
                return ConfigurationManager.AppSettings["AxConfigFile"];
            }
            catch (Exception e)
            {
                DAXUtils.LogEvent("DAXDataAccessLayer", "Failed to retrieve Dynamics Configuration! " + e.Message, EventLogEntryType.Error);
            }
            return null;
        }

        #region Log In Functionality
        // Method to Log in the integration
        public static void LOGON(Axapta axapta)
        {
            try
            {
                string AuthUser = ConfigurationManager.AppSettings["AxBCUser"];
                string AuthPassword = ConfigurationManager.AppSettings["AxBCPassword"];
                string AuthDomain = ConfigurationManager.AppSettings["AxBCDomain"];
                System.Net.NetworkCredential cred = new System.Net.NetworkCredential(AuthUser, AuthPassword, AuthDomain);
                axapta.LogonAs(AuthUser, AuthDomain, cred, "ldc", "en-us", "", GetAxConfig());
            }
            catch (Exception e)
            {
                DAXUtils.LogEvent("DAXDataAccessLayer", "Business Connector Log In Failed! " + e.Message, EventLogEntryType.Error);
                throw;
            }
        }

        // cleans up and resets the login to Dynamics
        private void ResetStaticLogon()
        {
            DAXUtils.LogEvent("DAXDataAccessLayer", "ResetStaticLogon Called! ", EventLogEntryType.Error);

            lock (lockAxObject)
            {
                try
                {
                    staticAxapta.Logoff();
                    staticAxapta.Dispose();
                }
                catch (Exception)
                { }

                staticAxapta = new Axapta();
                LOGON(staticAxapta);
            }
        }
        #endregion

       
        public DataTable AxExecQuery(string query, Type objectType)
        {
            AxaptaRecord ar;
            DataTable dt = new DataTable();
            // create and instance of our object class so we can run a generic method on it later via reflection
            object instance = Activator.CreateInstance(objectType);
          
            lock (lockAxObject)
            {
                try
                {
                    ar = staticAxapta.CreateAxaptaRecord(objectType.Name);
                    try
                    {
                        ar.ExecuteStmt(query);
                        // Call interface to set properties
                        dt = (DataTable)objectType.InvokeMember("MakeDataTable", BindingFlags.InvokeMethod, null, instance, new object[] { ar });
                        return dt;
                    }
                    finally
                    {
                        dt.Dispose();
                    }
                }
                catch (Exception e)
                {
                    DAXUtils.LogEvent("DAXDataAccessLayer", "Business Connector execQuery Failed! " + e.Message, EventLogEntryType.Error);
                    ResetStaticLogon();
                    throw e;
                }
            }
        }
    }
}

So far, so good. You probably have some build errors at this point, so let's add the other class, DAXUtils, which handles logging.

Here's DAXUtils...

using System;
using System.Diagnostics;

namespace LINQTest
{
    class DAXUtils
    {
        public static void LogEvent(string appName, string logMessage, EventLogEntryType eventLogType)
        {
            EventLog el = new EventLog();
            el.Source = appName;
            el.WriteEntry(logMessage, eventLogType);
        }
    }
}

Another utility class called TypeSystem is required, so let's add it. Create a class called TypeSystem, and add the following code to it. This is a generic, reusable class.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace LINQTest
{
    internal static class TypeSystem
    {
        internal static Type GetElementType(Type seqType)
        {
            Type ienum = FindIEnumerable(seqType);
            if (ienum == null) return seqType;
            return ienum.GetGenericArguments()[0];
        }
        private static Type FindIEnumerable(Type seqType)
        {
            if (seqType == null || seqType == typeof(string))
                return null;
            if (seqType.IsArray)
                return typeof(IEnumerable<>).MakeGenericType(seqType.GetElementType());
            if (seqType.IsGenericType)
            {
                foreach (Type arg in seqType.GetGenericArguments())
                {
                    Type ienum = typeof(IEnumerable<>).MakeGenericType(arg);
                    if (ienum.IsAssignableFrom(seqType))
                    {
                        return ienum;
                    }
                }
            }
            Type[] ifaces = seqType.GetInterfaces();
            if (ifaces != null && ifaces.Length > 0)
            {
                foreach (Type iface in ifaces)
                {
                    Type ienum = FindIEnumerable(iface);
                    if (ienum != null) return ienum;
                }
            }
            if (seqType.BaseType != null && seqType.BaseType != typeof(object))
            {
                return FindIEnumerable(seqType.BaseType);
            }
            return null;
        }
    }

}

You will notice that I use reflection in the AxExecQuery method to create an instance of the requested type, so that I can call a method on it that returns a DataTable. Crazy, huh? That is because the AxaptaRecord object in the Business Connector does not implement IEnumerable, which I need later. The queried type therefore needs a parameterless constructor and a MakeDataTable method that takes the AxaptaRecord, and its class name must match the AX table name. As in my previous Dynamics architectures, I use DataTables and DataReaders to do a lot of the work the AxaptaRecord object cannot. It's not cheating, it's overcoming.

Well, there we have it: the foundational layer of our new LINQ to Dynamics AX provider. I may change some things here, but for now I like it. Note that I have not implemented Update or Delete yet; those are in progress. Only the Where operator of the expression tree is translated, so only read-only select statements are supported.

Let's move on to the QueryProvider. It consumes our data access layer and also drives our modified query translator, because, as we all know, X++ queries are so much like T-SQL. Ah well, bad joke. I'll bet nobody is reading this for my humor, so let's keep going.

Add a new class to your project, name it DAXQueryProvider, and add the following code.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Data.Common;
using System.Reflection;
using System.Data;

namespace LINQTest
{
    public class DAXQueryProvider : QueryProvider
    {
        DAXDataAccessLayer connection;

        public DAXQueryProvider(DAXDataAccessLayer connection)
         {
            this.connection = connection;
         }


        public override string GetQueryText(Expression expression)
        {
            return this.Translate(expression).CommandText;
        }
        public override object Execute(Expression expression)
        {
            TranslateResult result = this.Translate(expression);
            DataTable dt = new DataTable();
            string query = result.CommandText;
            Type elementType = TypeSystem.GetElementType(expression.Type);

            dt = connection.AxExecQuery(query, elementType);
            if (result.Projector != null)
            {
                Delegate projector = result.Projector.Compile();
                return Activator.CreateInstance(typeof(ProjectionReader<>).MakeGenericType(elementType), BindingFlags.Instance | BindingFlags.NonPublic, null, new object[] { dt.CreateDataReader(), projector },null);
            }
            else
            {
                return Activator.CreateInstance(typeof(ObjectReader<>).MakeGenericType(elementType), BindingFlags.Instance | BindingFlags.NonPublic, null, new object[] { dt.CreateDataReader() }, null);
            }

        }

        private TranslateResult Translate(Expression expression)
        {
            expression = Evaluator.PartialEval(expression);
            return new QueryTranslator().Translate(expression);
        }

    }
}
This class derives from QueryProvider and references a few object-creation classes (ProjectionReader and ObjectReader) designed to populate our generic objects. It also calls a translator class to walk the expression tree and build our query. Note that I did NOT write these classes; I only modified them to translate LINQ queries into X++ queries. Since this is only a proof-of-concept project, I am reusing a lot of what is out there from some very talented folks. As my provider matures, they will most likely be updated or replaced completely. For now, they stay.
 
Now create another class and call this one QueryTranslator. As I mentioned, this class was written by someone else; I merely modified its expression-visiting methods to emit X++. You will also see a partial implementation of the Select expression. Ignore it: it is not complete yet and shouldn't affect your project.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;

namespace LINQTest
{
    internal class TranslateResult
    {
        internal string CommandText;
        internal LambdaExpression Projector;
    }

    internal class QueryTranslator : ExpressionVisitor
    {
        StringBuilder sb;
        ParameterExpression row;
        ColumnProjection projection;

        internal QueryTranslator()
        {
        }

      
        // added this stub for select
        internal TranslateResult Translate(Expression expression)
        {
            this.sb = new StringBuilder();
            this.row = Expression.Parameter(typeof(ProjectionRow), "row");
            this.Visit(expression);
            return new TranslateResult
            {
                CommandText = this.sb.ToString(),
                Projector = this.projection != null ? Expression.Lambda(this.projection.Selector, this.row) : null
            };
        }



        private static Expression StripQuotes(Expression e)
        {
            while (e.NodeType == ExpressionType.Quote)
            {
                e = ((UnaryExpression)e).Operand;
            }
            return e;
        }


        protected override Expression VisitMethodCall(MethodCallExpression m)
        {
            if (m.Method.DeclaringType == typeof(Queryable) && m.Method.Name == "Where")
            {
                this.Visit(m.Arguments[0]);
                sb.Append(" WHERE ");
                LambdaExpression lambda = (LambdaExpression)StripQuotes(m.Arguments[1]);
                this.Visit(lambda.Body);
                return m;
            }
            else if (m.Method.Name == "Select")
            {
                LambdaExpression lambda = (LambdaExpression)StripQuotes(m.Arguments[1]);
                ColumnProjection projection = new ColumnProjector().ProjectColumns(lambda.Body, this.row);
                sb.Append("SELECT ");
                sb.Append(projection.Columns);
                sb.Append(" FROM (");
                this.Visit(m.Arguments[0]);
                sb.Append(") AS T ");
                this.projection = projection;
                return m;
            }
            throw new NotSupportedException(string.Format("The method '{0}' is not supported", m.Method.Name));
        }


        protected override Expression VisitUnary(UnaryExpression u)
        {
            switch (u.NodeType)
            {
                case ExpressionType.Not:
                    sb.Append(" NOT ");
                    this.Visit(u.Operand);
                    break;
                default:
                    throw new NotSupportedException(string.Format("The unary operator '{0}' is not supported", u.NodeType));
            }
            return u;
        }


        protected override Expression VisitBinary(BinaryExpression b)
        {
            //sb.Append("(");
            this.Visit(b.Left);
            switch (b.NodeType)
            {
                case ExpressionType.And:
                    sb.Append(" AND ");
                    break;
                case ExpressionType.AndAlso:
                    sb.Append(" && ");
                    break;
                case ExpressionType.Or:
                    sb.Append(" || ");
                    break;
                case ExpressionType.OrElse:
                    sb.Append(" || ");
                    break;
                case ExpressionType.Equal:
                    sb.Append(" == ");
                    break;
                case ExpressionType.NotEqual:
                    sb.Append(" <> ");
                    break;
                case ExpressionType.LessThan:
                    sb.Append(" < ");
                    break;
                case ExpressionType.LessThanOrEqual:
                    sb.Append(" <= ");
                    break;
                case ExpressionType.GreaterThan:
                    sb.Append(" > ");
                    break;
                case ExpressionType.GreaterThanOrEqual:
                    sb.Append(" >= ");
                    break;
                default:
                    throw new NotSupportedException(string.Format("The binary operator '{0}' is not supported", b.NodeType));
            }
            this.Visit(b.Right);
            //sb.Append(")");
            return b;
        }


        protected override Expression VisitConstant(ConstantExpression c)
        {
            IQueryable q = c.Value as IQueryable;
            if (q != null)
            {
                // assume constant nodes w/ IQueryables are table references
                sb.Append("SELECT * FROM %1");
            }
            else if (c.Value == null)
            {
                sb.Append("NULL");
            }
            else
            {
                switch (Type.GetTypeCode(c.Value.GetType()))
                {
                    case TypeCode.Boolean:
                        sb.Append(((bool)c.Value) ? 1 : 0);
                        break;
                    case TypeCode.String:
                        sb.Append("'");
                        sb.Append(c.Value);
                        sb.Append("'");
                        break;
                    case TypeCode.Object:
                        throw new NotSupportedException(string.Format("The constant for '{0}' is not supported", c.Value));
                    default:
                        sb.Append(c.Value);
                        break;
                }
            }
            return c;
        }


        protected override Expression VisitMemberAccess(MemberExpression m)
        {
            if (m.Expression != null && m.Expression.NodeType == ExpressionType.Parameter)
            {
                sb.Append("%1." + m.Member.Name);
                return m;
            }
            throw new NotSupportedException(string.Format("The member '{0}' is not supported", m.Member.Name));
        }
    }
}
 

Now create a class called ExpressionVisitor, and add the following code.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Collections.ObjectModel;

namespace LINQTest
{
    public abstract class ExpressionVisitor
    {
        protected ExpressionVisitor()
        {
        }


        protected virtual Expression Visit(Expression exp)
        {
            if (exp == null)
                return exp;
            switch (exp.NodeType)
            {
                case ExpressionType.Negate:
                case ExpressionType.NegateChecked:
                case ExpressionType.Not:
                case ExpressionType.Convert:
                case ExpressionType.ConvertChecked:
                case ExpressionType.ArrayLength:
                case ExpressionType.Quote:
                case ExpressionType.TypeAs:
                    return this.VisitUnary((UnaryExpression)exp);
                case ExpressionType.Add:
                case ExpressionType.AddChecked:
                case ExpressionType.Subtract:
                case ExpressionType.SubtractChecked:
                case ExpressionType.Multiply:
                case ExpressionType.MultiplyChecked:
                case ExpressionType.Divide:
                case ExpressionType.Modulo:
                case ExpressionType.And:
                case ExpressionType.AndAlso:
                    return this.VisitBinary((BinaryExpression)exp);
                case ExpressionType.Or:
                case ExpressionType.OrElse:
                    return this.VisitBinary((BinaryExpression)exp);
                case ExpressionType.LessThan:
                case ExpressionType.LessThanOrEqual:
                case ExpressionType.GreaterThan:
                case ExpressionType.GreaterThanOrEqual:
                case ExpressionType.Equal:
                case ExpressionType.NotEqual:
                case ExpressionType.Coalesce:
                case ExpressionType.ArrayIndex:
                case ExpressionType.RightShift:
                case ExpressionType.LeftShift:
                case ExpressionType.ExclusiveOr:
                    return this.VisitBinary((BinaryExpression)exp);
                case ExpressionType.TypeIs:
                    return this.VisitTypeIs((TypeBinaryExpression)exp);
                case ExpressionType.Conditional:
                    return this.VisitConditional((ConditionalExpression)exp);
                case ExpressionType.Constant:
                    return this.VisitConstant((ConstantExpression)exp);
                case ExpressionType.Parameter:
                    return this.VisitParameter((ParameterExpression)exp);
                case ExpressionType.MemberAccess:
                    return this.VisitMemberAccess((MemberExpression)exp);
                case ExpressionType.Call:
                    return this.VisitMethodCall((MethodCallExpression)exp);
                case ExpressionType.Lambda:
                    return this.VisitLambda((LambdaExpression)exp);
                case ExpressionType.New:
                    return this.VisitNew((NewExpression)exp);
                case ExpressionType.NewArrayInit:
                case ExpressionType.NewArrayBounds:
                    return this.VisitNewArray((NewArrayExpression)exp);
                case ExpressionType.Invoke:
                    return this.VisitInvocation((InvocationExpression)exp);
                case ExpressionType.MemberInit:
                    return this.VisitMemberInit((MemberInitExpression)exp);
                case ExpressionType.ListInit:
                    return this.VisitListInit((ListInitExpression)exp);
                default:
                    throw new Exception(string.Format("Unhandled expression type: '{0}'", exp.NodeType));
            }
        }


        protected virtual MemberBinding VisitBinding(MemberBinding binding)
        {
            switch (binding.BindingType)
            {
                case MemberBindingType.Assignment:
                    return this.VisitMemberAssignment((MemberAssignment)binding);
                case MemberBindingType.MemberBinding:
                    return this.VisitMemberMemberBinding((MemberMemberBinding)binding);
                case MemberBindingType.ListBinding:
                    return this.VisitMemberListBinding((MemberListBinding)binding);
                default:
                    throw new Exception(string.Format("Unhandled binding type '{0}'", binding.BindingType));
            }
        }


        protected virtual ElementInit VisitElementInitializer(ElementInit initializer)
        {
            ReadOnlyCollection<Expression> arguments = this.VisitExpressionList(initializer.Arguments);
            if (arguments != initializer.Arguments)
            {
                return Expression.ElementInit(initializer.AddMethod, arguments);
            }
            return initializer;
        }


        protected virtual Expression VisitUnary(UnaryExpression u)
        {
            Expression operand = this.Visit(u.Operand);
            if (operand != u.Operand)
            {
                return Expression.MakeUnary(u.NodeType, operand, u.Type, u.Method);
            }
            return u;
        }


        protected virtual Expression VisitBinary(BinaryExpression b)
        {
            Expression left = this.Visit(b.Left);
            Expression right = this.Visit(b.Right);
            Expression conversion = this.Visit(b.Conversion);
            if (left != b.Left || right != b.Right || conversion != b.Conversion)
            {
                if (b.NodeType == ExpressionType.Coalesce && b.Conversion != null)
                    return Expression.Coalesce(left, right, conversion as LambdaExpression);
                else
                    return Expression.MakeBinary(b.NodeType, left, right, b.IsLiftedToNull, b.Method);
            }
            return b;
        }


        protected virtual Expression VisitTypeIs(TypeBinaryExpression b)
        {
            Expression expr = this.Visit(b.Expression);
            if (expr != b.Expression)
            {
                return Expression.TypeIs(expr, b.TypeOperand);
            }
            return b;
        }


        protected virtual Expression VisitConstant(ConstantExpression c)
        {
            return c;
        }


        protected virtual Expression VisitConditional(ConditionalExpression c)
        {
            Expression test = this.Visit(c.Test);
            Expression ifTrue = this.Visit(c.IfTrue);
            Expression ifFalse = this.Visit(c.IfFalse);
            if (test != c.Test || ifTrue != c.IfTrue || ifFalse != c.IfFalse)
            {
                return Expression.Condition(test, ifTrue, ifFalse);
            }
            return c;
        }


        protected virtual Expression VisitParameter(ParameterExpression p)
        {
            return p;
        }


        protected virtual Expression VisitMemberAccess(MemberExpression m)
        {
            Expression exp = this.Visit(m.Expression);
            if (exp != m.Expression)
            {
                return Expression.MakeMemberAccess(exp, m.Member);
            }
            return m;
        }


        protected virtual Expression VisitMethodCall(MethodCallExpression m)
        {
            Expression obj = this.Visit(m.Object);
            IEnumerable<Expression> args = this.VisitExpressionList(m.Arguments);
            if (obj != m.Object || args != m.Arguments)
            {
                return Expression.Call(obj, m.Method, args);
            }
            return m;
        }


        protected virtual ReadOnlyCollection<Expression> VisitExpressionList(ReadOnlyCollection<Expression> original)
        {
            List<Expression> list = null;
            for (int i = 0, n = original.Count; i < n; i++)
            {
                Expression p = this.Visit(original[i]);
                if (list != null)
                {
                    list.Add(p);
                }
                else if (p != original[i])
                {
                    list = new List<Expression>(n);
                    for (int j = 0; j < i; j++)
                    {
                        list.Add(original[j]);
                    }
                    list.Add(p);
                }
            }
            if (list != null)
            {
                return list.AsReadOnly();
            }
            return original;
        }


        protected virtual MemberAssignment VisitMemberAssignment(MemberAssignment assignment)
        {
            Expression e = this.Visit(assignment.Expression);
            if (e != assignment.Expression)
            {
                return Expression.Bind(assignment.Member, e);
            }
            return assignment;
        }


        protected virtual MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding)
        {
            IEnumerable<MemberBinding> bindings = this.VisitBindingList(binding.Bindings);
            if (bindings != binding.Bindings)
            {
                return Expression.MemberBind(binding.Member, bindings);
            }
            return binding;
        }


        protected virtual MemberListBinding VisitMemberListBinding(MemberListBinding binding)
        {
            IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(binding.Initializers);
            if (initializers != binding.Initializers)
            {
                return Expression.ListBind(binding.Member, initializers);
            }
            return binding;
        }


        protected virtual IEnumerable<MemberBinding> VisitBindingList(ReadOnlyCollection<MemberBinding> original)
        {
            List<MemberBinding> list = null;
            for (int i = 0, n = original.Count; i < n; i++)
            {
                MemberBinding b = this.VisitBinding(original[i]);
                if (list != null)
                {
                    list.Add(b);
                }
                else if (b != original[i])
                {
                    list = new List<MemberBinding>(n);
                    for (int j = 0; j < i; j++)
                    {
                        list.Add(original[j]);
                    }
                    list.Add(b);
                }
            }
            if (list != null)
                return list;
            return original;
        }


        protected virtual IEnumerable<ElementInit> VisitElementInitializerList(ReadOnlyCollection<ElementInit> original)
        {
            List<ElementInit> list = null;
            for (int i = 0, n = original.Count; i < n; i++)
            {
                ElementInit init = this.VisitElementInitializer(original[i]);
                if (list != null)
                {
                    list.Add(init);
                }
                else if (init != original[i])
                {
                    list = new List<ElementInit>(n);
                    for (int j = 0; j < i; j++)
                    {
                        list.Add(original[j]);
                    }
                    list.Add(init);
                }
            }
            if (list != null)
                return list;
            return original;
        }


        protected virtual Expression VisitLambda(LambdaExpression lambda)
        {
            Expression body = this.Visit(lambda.Body);
            if (body != lambda.Body)
            {
                return Expression.Lambda(lambda.Type, body, lambda.Parameters);
            }
            return lambda;
        }


        protected virtual NewExpression VisitNew(NewExpression nex)
        {
            IEnumerable<Expression> args = this.VisitExpressionList(nex.Arguments);
            if (args != nex.Arguments)
            {
                if (nex.Members != null)
                    return Expression.New(nex.Constructor, args, nex.Members);
                else
                    return Expression.New(nex.Constructor, args);
            }
            return nex;
        }


        protected virtual Expression VisitMemberInit(MemberInitExpression init)
        {
            NewExpression n = this.VisitNew(init.NewExpression);
            IEnumerable<MemberBinding> bindings = this.VisitBindingList(init.Bindings);
            if (n != init.NewExpression || bindings != init.Bindings)
            {
                return Expression.MemberInit(n, bindings);
            }
            return init;
        }


        protected virtual Expression VisitListInit(ListInitExpression init)
        {
            NewExpression n = this.VisitNew(init.NewExpression);
            IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(init.Initializers);
            if (n != init.NewExpression || initializers != init.Initializers)
            {
                return Expression.ListInit(n, initializers);
            }
            return init;
        }


        protected virtual Expression VisitNewArray(NewArrayExpression na)
        {
            IEnumerable<Expression> exprs = this.VisitExpressionList(na.Expressions);
            if (exprs != na.Expressions)
            {
                if (na.NodeType == ExpressionType.NewArrayInit)
                {
                    return Expression.NewArrayInit(na.Type.GetElementType(), exprs);
                }
                else
                {
                    return Expression.NewArrayBounds(na.Type.GetElementType(), exprs);
                }
            }
            return na;
        }


        protected virtual Expression VisitInvocation(InvocationExpression iv)
        {
            IEnumerable<Expression> args = this.VisitExpressionList(iv.Arguments);
            Expression expr = this.Visit(iv.Expression);
            if (args != iv.Arguments || expr != iv.Expression)
            {
                return Expression.Invoke(expr, args);
            }
            return iv;
        }
    }
}

Now create a class called Query, and add the following code.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Collections;

namespace LINQTest
{
    public class Query<T> : IQueryable<T>, IQueryable, IEnumerable<T>, IEnumerable, IOrderedQueryable<T>, IOrderedQueryable
    {
        QueryProvider provider;
        Expression expression;


        public Query(QueryProvider provider)
        {
            if (provider == null)
            {
                throw new ArgumentNullException("provider");
            }
            this.provider = provider;
            this.expression = Expression.Constant(this);
        }


        public Query(QueryProvider provider, Expression expression)
        {
            if (provider == null)
            {
                throw new ArgumentNullException("provider");
            }
            if (expression == null)
            {
                throw new ArgumentNullException("expression");
            }
            if (!typeof(IQueryable<T>).IsAssignableFrom(expression.Type))
            {
                throw new ArgumentOutOfRangeException("expression");
            }
            this.provider = provider;
            this.expression = expression;
        }


        Expression IQueryable.Expression
        {
            get { return this.expression; }
        }


        Type IQueryable.ElementType
        {
            get { return typeof(T); }
        }


        IQueryProvider IQueryable.Provider
        {
            get { return this.provider; }
        }


        public IEnumerator<T> GetEnumerator()
        {
            return ((IEnumerable<T>)this.provider.Execute(this.expression)).GetEnumerator();
        }


        IEnumerator IEnumerable.GetEnumerator()
        {
            return ((IEnumerable)this.provider.Execute(this.expression)).GetEnumerator();
        }


        public override string ToString()
        {
            return this.provider.GetQueryText(this.expression);
        }
    }
}

Now create a class called QueryProvider and add the following code to it.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Reflection;

namespace LINQTest
{
    public abstract class QueryProvider : IQueryProvider
    {
        protected QueryProvider()
        {
        }


        IQueryable<S> IQueryProvider.CreateQuery<S>(Expression expression)
        {
            return new Query<S>(this, expression);
        }


        IQueryable IQueryProvider.CreateQuery(Expression expression)
        {
            Type elementType = TypeSystem.GetElementType(expression.Type);
            try
            {
                return (IQueryable)Activator.CreateInstance(typeof(Query<>).MakeGenericType(elementType), new object[] { this, expression });
            }
            catch (TargetInvocationException tie)
            {
                throw tie.InnerException;
            }
        }


        S IQueryProvider.Execute<S>(Expression expression)
        {
            return (S)this.Execute(expression);
        }


        object IQueryProvider.Execute(Expression expression)
        {
            return this.Execute(expression);
        }


        public abstract string GetQueryText(Expression expression);
        public abstract object Execute(Expression expression);
    }
}

Now create a class called Evaluator and add the following code.

 using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;

namespace LINQTest
{
    public static class Evaluator
    {
        /// <summary>
        /// Performs evaluation & replacement of independent sub-trees
        /// </summary>
        /// <param name="expression">The root of the expression tree.</param>
        /// <param name="fnCanBeEvaluated">A function that decides whether a given expression node can be part of the local function.</param>
        /// <returns>A new tree with sub-trees evaluated and replaced.</returns>
        public static Expression PartialEval(Expression expression, Func<Expression, bool> fnCanBeEvaluated)
        {
            return new SubtreeEvaluator(new Nominator(fnCanBeEvaluated).Nominate(expression)).Eval(expression);
        }


        /// <summary>
        /// Performs evaluation & replacement of independent sub-trees
        /// </summary>
        /// <param name="expression">The root of the expression tree.</param>
        /// <returns>A new tree with sub-trees evaluated and replaced.</returns>
        public static Expression PartialEval(Expression expression)
        {
            return PartialEval(expression, Evaluator.CanBeEvaluatedLocally);
        }


        private static bool CanBeEvaluatedLocally(Expression expression)
        {
            return expression.NodeType != ExpressionType.Parameter;
        }

    }
    /// <summary>
    /// Evaluates & replaces sub-trees when first candidate is reached (top-down)
    /// </summary>
    class SubtreeEvaluator : ExpressionVisitor
    {
        HashSet<Expression> candidates;


        internal SubtreeEvaluator(HashSet<Expression> candidates)
        {
            this.candidates = candidates;
        }


        internal Expression Eval(Expression exp)
        {
            return this.Visit(exp);
        }


        protected override Expression Visit(Expression exp)
        {
            if (exp == null)
            {
                return null;
            }
            if (this.candidates.Contains(exp))
            {
                return this.Evaluate(exp);
            }
            return base.Visit(exp);
        }


        private Expression Evaluate(Expression e)
        {
            if (e.NodeType == ExpressionType.Constant)
            {
                return e;
            }
            LambdaExpression lambda = Expression.Lambda(e);
            Delegate fn = lambda.Compile();
            return Expression.Constant(fn.DynamicInvoke(null), e.Type);
        }
    }

    /// <summary>
    /// Performs bottom-up analysis to determine which nodes can possibly
    /// be part of an evaluated sub-tree.
    /// </summary>
    /// <remarks>
    /// A node is added to the candidate set only when <c>fnCanBeEvaluated</c>
    /// accepts it and every node beneath it was accepted as well. A single
    /// rejected node therefore disqualifies all of its ancestors.
    /// </remarks>
    class Nominator : ExpressionVisitor
    {
        // Per-node predicate supplied by the caller (Evaluator.PartialEval passes it in).
        Func<Expression, bool> fnCanBeEvaluated;
        // Nodes nominated during the current Nominate call.
        HashSet<Expression> candidates;
        // True once a rejected node has been seen in the part of the tree visited
        // so far at the current level; merged upward as Visit returns.
        bool cannotBeEvaluated;


        internal Nominator(Func<Expression, bool> fnCanBeEvaluated)
        {
            this.fnCanBeEvaluated = fnCanBeEvaluated;
        }


        /// <summary>
        /// Walks <paramref name="expression"/> and returns every node whose whole
        /// sub-tree satisfies the predicate. A new set is allocated on each call.
        /// </summary>
        internal HashSet<Expression> Nominate(Expression expression)
        {
            this.candidates = new HashSet<Expression>();
            this.Visit(expression);
            return this.candidates;
        }


        protected override Expression Visit(Expression expression)
        {
            if (expression != null)
            {
                // Remember what earlier siblings recorded, so it can be merged back
                // after this node's own sub-tree has been judged.
                bool saveCannotBeEvaluated = this.cannotBeEvaluated;
                // Judge this sub-tree on its own, starting from a clean flag.
                this.cannotBeEvaluated = false;
                // base.Visit recurses into the children (each comes back through this
                // override), so afterwards the flag says whether anything below was rejected.
                base.Visit(expression);
                if (!this.cannotBeEvaluated)
                {
                    if (this.fnCanBeEvaluated(expression))
                    {
                        this.candidates.Add(expression);
                    }
                    else
                    {
                        // Reject this node; the flag will propagate to every ancestor.
                        this.cannotBeEvaluated = true;
                    }
                }
                // Report upward: this sub-tree's rejection OR anything recorded before it.
                this.cannotBeEvaluated |= saveCannotBeEvaluated;
            }
            // Analysis only; the tree is returned unchanged.
            return expression;
        }
    }
}

The next file deals with the Select implementation on the expression tree. As I said before, it's not complete, but it is necessary to include because other code references it, and this project is a work in progress. Create a new class file named ISelect and copy in the following code.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Reflection;

namespace LINQTest
{
    public abstract class ProjectionRow
    {
        public abstract object GetValue(int index);
    }

    internal class ColumnProjection
    {
        internal string Columns;
        internal Expression Selector;
    }


    internal class ColumnProjector : ExpressionVisitor
    {
        StringBuilder sb;
        int iColumn;
        ParameterExpression row;
        static MethodInfo miGetValue;


        internal ColumnProjector()
        {
            if (miGetValue == null)
            {
                miGetValue = typeof(ProjectionRow).GetMethod("GetValue");
            }
        }


        internal ColumnProjection ProjectColumns(Expression expression, ParameterExpression row)
        {
            this.sb = new StringBuilder();
            this.row = row;
            Expression selector = this.Visit(expression);
            return new ColumnProjection { Columns = this.sb.ToString(), Selector = selector };
        }


        protected override Expression VisitMemberAccess(MemberExpression m)
        {
            if (m.Expression != null && m.Expression.NodeType == ExpressionType.Parameter)
            {
                if (this.sb.Length > 0)
                {
                    this.sb.Append(", ");
                }
                this.sb.Append(m.Member.Name);
                return Expression.Convert(Expression.Call(this.row, miGetValue, Expression.Constant(iColumn++)), m.Type);
            }
            else
            {
                return base.VisitMemberAccess(m);
            }
        }
    }
}


Phew, that's all done. All that's left now is the actual data class and a couple of helper classes to populate it generically; I will cover those in the next post.

Thanks,



No comments:

Post a Comment