Anda di halaman 1dari 7

Generating SQL From LINQ Expression Trees

An expression tree is an abstract representation of code as data. In .NET they are primarily used for LINQ-style
code. In C#, lambda expressions can be decomposed into expression trees. Here is an example of a
lambda expression and its expression tree:

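x => x.PosId == 1 && x.Name == "Main"

AndAlso                (BinaryExpression)
  Equal                (BinaryExpression)
    x.PosId            (MemberExpression on ParameterExpression x)
    1                  (ConstantExpression)
  Equal                (BinaryExpression)
    x.Name             (MemberExpression on ParameterExpression x)
    "Main"             (ConstantExpression)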

There are five key types of expression tree nodes.

UnaryExpression: An operation with a single operand such as negation.

BinaryExpression: An operation with two operands such as addition or ‘&&’ or ‘||’.

MemberExpression: Accessing a field or property of an object. (Variables captured by a lambda expression are
implemented as fields on a class generated by the compiler, so references to them are also MemberExpressions.)

ConstantExpression: A node that is a constant value.

ParameterExpression: An input to a lambda function.

The following code recursively walks an expression tree and generates the equivalent WHERE clause in SQL, for
sufficiently simple expressions. One of the areas that was a bit tricky is SQL’s handling of NULL. I have to check the
right side of a binary expression for NULL so I can generate “x IS NULL” instead of “x = NULL”. I used parentheses
liberally to ease composing the expressions. Handling negation was done naively. It could be cleaned up by
propagating the negation into the child node.

Input Output
x => x.PosId == 1 ([PosId] = 1)
x => x.IsAborted ([IsAborted] = 1)
x => !x.IsAborted (NOT ([IsAborted] = 1))
x => x.Name == null ([Name] IS NULL)
x => x.PosId == posId (where posId = 2) ([PosId] = 2)
x => x.PosId == 1 && x.Name == "Main" (([PosId] = 1) AND ([Name] = 'Main'))
x => x.Name.Contains("Main") ([Name] LIKE '%Main%')
x => x.Name.StartsWith("R") ([Name] LIKE 'R%')
x => list.Contains(x.PosId) ([PosId] IN (1, 2, 3))

/// <summary>
/// First (literal-inlining) version: translates a sufficiently simple predicate lambda into the text of a
/// SQL WHERE clause (without the WHERE keyword). Values are formatted into the SQL through the provider's
/// <c>ValueToString</c>, so this version is not suitable for untrusted input.
/// </summary>
public class WhereBuilder
{
    private readonly IProvider _provider;
    private TableDefinition _tableDef;

    public WhereBuilder(IProvider provider)
    {
        _provider = provider;
    }

    /// <summary>
    /// Translates <paramref name="expression"/> into WHERE-clause SQL. Column names are looked up in the
    /// provider's table definition for <typeparamref name="T"/>.
    /// </summary>
    /// <exception cref="NotSupportedException">The expression uses a node, operator or method call that is not translated.</exception>
    public string ToSql<T>(Expression<Func<T, bool>> expression)
    {
        _tableDef = _provider.GetTableDefinitionFor<T>();
        return Recurse(expression.Body, true);
    }

    /// <summary>Recursively translates one expression node into SQL text.</summary>
    /// <param name="expression">The node to translate.</param>
    /// <param name="isUnary">
    /// True when the node has to stand alone as a predicate (the lambda body, the operand of NOT, or an operand
    /// of AND/OR); boolean columns and values are then rendered as comparisons instead of bare values.
    /// </param>
    /// <param name="quote">Forwarded to the provider when formatting values; false for LIKE patterns, which are quoted here.</param>
    private string Recurse(Expression expression, bool isUnary = false, bool quote = true)
    {
        if (expression is UnaryExpression)
        {
            var unary = (UnaryExpression)expression;

            // Casts inserted by the compiler (nullable lifting, enum-to-int comparisons) have no SQL counterpart.
            if (unary.NodeType == ExpressionType.Convert || unary.NodeType == ExpressionType.ConvertChecked)
            {
                return Recurse(unary.Operand, isUnary, quote);
            }

            var operand = Recurse(unary.Operand, true);
            return "(" + NodeTypeToString(unary.NodeType, operand == "NULL") + " " + operand + ")";
        }

        if (expression is BinaryExpression)
        {
            var body = (BinaryExpression)expression;

            // Operands of AND/OR must be predicates themselves, e.g. a bare bool column becomes "([col] = 1)".
            var operandsArePredicates =
                body.NodeType == ExpressionType.AndAlso || body.NodeType == ExpressionType.OrElse;
            var left = Recurse(body.Left, operandsArePredicates);
            var right = Recurse(body.Right, operandsArePredicates);

            // NodeTypeToString only inspects the right operand for NULL (relying, as before, on the provider
            // rendering null as "NULL"), so move a null literal written on the left ("null == x.Name") there.
            var isEquality = body.NodeType == ExpressionType.Equal || body.NodeType == ExpressionType.NotEqual;
            if (isEquality && left == "NULL" && right != "NULL")
            {
                var swap = left;
                left = right;
                right = swap;
            }

            return "(" + left + " " + NodeTypeToString(body.NodeType, right == "NULL") + " " + right + ")";
        }

        if (expression is ConstantExpression)
        {
            var constant = (ConstantExpression)expression;
            return ValueToString(constant.Value, isUnary, quote);
        }

        if (expression is MemberExpression)
        {
            var member = (MemberExpression)expression;
            var property = member.Member as PropertyInfo;

            // Only members reached through the lambda parameter are columns. Captured variables, properties of
            // captured objects and static members are evaluated now and formatted as values.
            if (property != null && IsRootedInParameter(member))
            {
                var colName = _tableDef.GetColumnNameFor(property.Name);
                if (isUnary && member.Type == typeof(bool))
                {
                    return "([" + colName + "] = 1)";
                }
                return "[" + colName + "]";
            }

            if (property != null || member.Member is FieldInfo)
            {
                return ValueToString(GetValue(member), isUnary, quote);
            }

            throw new NotSupportedException($"Expression does not refer to a property or field: {expression}");
        }

        if (expression is MethodCallExpression)
        {
            var methodCall = (MethodCallExpression)expression;

            // LIKE queries: the pattern is formatted unquoted so the wildcards can sit inside the quotes.
            // NOTE(review): whether the provider escapes quotes and LIKE wildcards when quote == false is not
            // visible here; untrusted patterns should go through the parameterized builder instead.
            if (methodCall.Method == typeof(string).GetMethod("Contains", new[] { typeof(string) }))
            {
                return "(" + Recurse(methodCall.Object) + " LIKE '%" +
                       Recurse(methodCall.Arguments[0], quote: false) + "%')";
            }
            if (methodCall.Method == typeof(string).GetMethod("StartsWith", new[] { typeof(string) }))
            {
                return "(" + Recurse(methodCall.Object) + " LIKE '" +
                       Recurse(methodCall.Arguments[0], quote: false) + "%')";
            }
            if (methodCall.Method == typeof(string).GetMethod("EndsWith", new[] { typeof(string) }))
            {
                return "(" + Recurse(methodCall.Object) + " LIKE '%" +
                       Recurse(methodCall.Arguments[0], quote: false) + "')";
            }

            // IN queries:
            if (methodCall.Method.Name == "Contains")
            {
                return InList(methodCall);
            }

            throw new NotSupportedException("Unsupported method call: " + methodCall.Method.Name);
        }

        throw new NotSupportedException("Unsupported expression: " + expression.GetType().Name);
    }

    /// <summary>
    /// Translates <c>list.Contains(x.Col)</c> or <c>Enumerable.Contains(list, x.Col)</c> into "([col] IN (v1, v2, ...))".
    /// </summary>
    private string InList(MethodCallExpression methodCall)
    {
        var isExtension = methodCall.Method.IsDefined(typeof(ExtensionAttribute));
        Expression collection;
        Expression property;
        if (isExtension && methodCall.Arguments.Count == 2)
        {
            collection = methodCall.Arguments[0];
            property = methodCall.Arguments[1];
        }
        else if (!isExtension && methodCall.Arguments.Count == 1)
        {
            collection = methodCall.Object;
            property = methodCall.Arguments[0];
        }
        else
        {
            throw new NotSupportedException("Unsupported method call: " + methodCall.Method.Name);
        }

        var literals = ((IEnumerable)GetValue(collection))
            .Cast<object>()
            .Select(value => ValueToString(value, false, true))
            .ToList();

        // "IN ()" is invalid SQL; an always-false predicate has the same meaning for an empty collection.
        if (literals.Count == 0)
        {
            return ValueToString(false, true, false);
        }

        return "(" + Recurse(property) + " IN (" + string.Join(", ", literals) + "))";
    }

    /// <summary>
    /// Formats a value as SQL. Booleans become 1/0, or (1=1)/(1=0) when they must stand alone as a predicate;
    /// everything else is delegated to the provider.
    /// </summary>
    public string ValueToString(object value, bool isUnary, bool quote)
    {
        if (value is bool)
        {
            if (isUnary)
            {
                return (bool)value ? "(1=1)" : "(1=0)";
            }
            return (bool)value ? "1" : "0";
        }
        return _provider.ValueToString(value, quote);
    }

    /// <summary>
    /// True when the member chain (looking through casts) starts at a lambda parameter, e.g. <c>x.Name</c>;
    /// false for captured variables, members of captured objects and static members.
    /// </summary>
    private static bool IsRootedInParameter(MemberExpression member)
    {
        var current = member.Expression;
        while (current != null)
        {
            if (current is MemberExpression)
            {
                current = ((MemberExpression)current).Expression;
            }
            else if (current.NodeType == ExpressionType.Convert ||
                     current.NodeType == ExpressionType.ConvertChecked ||
                     current.NodeType == ExpressionType.TypeAs)
            {
                current = ((UnaryExpression)current).Operand;
            }
            else
            {
                break;
            }
        }
        return current is ParameterExpression;
    }

    /// <summary>Evaluates an expression that does not depend on the lambda parameter.</summary>
    private static object GetValue(Expression member)
    {
        // source: http://stackoverflow.com/a/2616980/291955
        var objectMember = Expression.Convert(member, typeof(object));
        var getterLambda = Expression.Lambda<Func<object>>(objectMember);
        var getter = getterLambda.Compile();
        return getter();
    }

    /// <summary>
    /// Maps an expression node type to its SQL operator. Comparisons against NULL use IS / IS NOT, because
    /// "x = NULL" and "x &lt;&gt; NULL" are never true in SQL.
    /// </summary>
    private static string NodeTypeToString(ExpressionType nodeType, bool rightIsNull)
    {
        switch (nodeType)
        {
            case ExpressionType.Add:
                return "+";
            case ExpressionType.And:
                return "&";
            case ExpressionType.AndAlso:
                return "AND";
            case ExpressionType.Divide:
                return "/";
            case ExpressionType.Equal:
                return rightIsNull ? "IS" : "=";
            case ExpressionType.ExclusiveOr:
                return "^";
            case ExpressionType.GreaterThan:
                return ">";
            case ExpressionType.GreaterThanOrEqual:
                return ">=";
            case ExpressionType.LessThan:
                return "<";
            case ExpressionType.LessThanOrEqual:
                return "<=";
            case ExpressionType.Modulo:
                return "%";
            case ExpressionType.Multiply:
                return "*";
            case ExpressionType.Negate:
                return "-";
            case ExpressionType.Not:
                return "NOT";
            case ExpressionType.NotEqual:
                return rightIsNull ? "IS NOT" : "<>";
            case ExpressionType.Or:
                return "|";
            case ExpressionType.OrElse:
                return "OR";
            case ExpressionType.Subtract:
                return "-";
        }
        throw new NotSupportedException($"Unsupported node type: {nodeType}");
    }
}

The following is a much-improved version of the WHERE clause builder. This version generates parameterized queries,
so it isn’t vulnerable to SQL injection. Using parameters also allowed me to simplify the logic, since I don’t need to
worry about stringifying the values. Parameters do not, however, remove the need to use “IS” instead of “=” for null
checking: under standard SQL semantics (ANSI_NULLS ON in SQL Server), “col = @p” is never true when @p is NULL.

I moved all the string concatenation into a separate class called WherePart. These objects are composable in a
structure similar to the source expression tree. Extracting this class is my favorite part of the refactoring.

I’m still not happy with how I’m handling the LIKE queries. I have to pass a prefix and postfix parameter down to the
next level of recursion which clutters up the method signature. It might be better to just build the string in place.
/// <summary>
/// Parameterized version: translates a sufficiently simple predicate lambda into a <see cref="WherePart"/>.
/// The part holds WHERE-clause SQL (without the WHERE keyword) whose values are all parameters named
/// @1, @2, ..., except integer literals, which are inlined.
/// </summary>
public class WhereBuilder
{
    private readonly IProvider _provider;
    private TableDefinition _tableDef;

    public WhereBuilder(IProvider provider)
    {
        _provider = provider;
    }

    /// <summary>
    /// Translates <paramref name="expression"/> into parameterized WHERE-clause SQL. Column names are looked up
    /// in the provider's table definition for <typeparamref name="T"/>.
    /// </summary>
    /// <returns>The SQL text plus the parameter values it references, keyed "1", "2", ... (used as @1, @2, ...).</returns>
    /// <exception cref="NotSupportedException">The expression uses a node, operator or method call that is not translated.</exception>
    public WherePart ToSql<T>(Expression<Func<T, bool>> expression)
    {
        _tableDef = _provider.GetTableDefinitionFor<T>();
        var i = 1;
        return Recurse(ref i, expression.Body, isUnary: true);
    }

    /// <summary>Recursively translates one expression node.</summary>
    /// <param name="i">Number of the next parameter; incremented for every parameter emitted.</param>
    /// <param name="expression">The node to translate.</param>
    /// <param name="isUnary">
    /// True when the node has to stand alone as a predicate (the lambda body, the operand of NOT, or an operand
    /// of AND/OR); boolean columns and values are then rendered as comparisons.
    /// </param>
    /// <param name="prefix">LIKE wildcard text placed before a string value.</param>
    /// <param name="postfix">LIKE wildcard text placed after a string value.</param>
    private WherePart Recurse(
        ref int i, Expression expression, bool isUnary = false, string prefix = null, string postfix = null)
    {
        if (expression is UnaryExpression)
        {
            var unary = (UnaryExpression)expression;

            // Casts inserted by the compiler (nullable lifting, enum-to-int comparisons) have no SQL counterpart.
            if (unary.NodeType == ExpressionType.Convert || unary.NodeType == ExpressionType.ConvertChecked)
            {
                return Recurse(ref i, unary.Operand, isUnary, prefix, postfix);
            }

            return WherePart.Concat(NodeTypeToString(unary.NodeType), Recurse(ref i, unary.Operand, true));
        }

        if (expression is BinaryExpression)
        {
            var body = (BinaryExpression)expression;

            if (body.NodeType == ExpressionType.Equal || body.NodeType == ExpressionType.NotEqual)
            {
                // "col = @p" is never true when @p is NULL, so comparisons against null need IS [NOT] NULL.
                var nullTest = body.NodeType == ExpressionType.Equal ? "IS" : "IS NOT";
                if (IsNullValue(body.Right))
                {
                    return WherePart.Concat(Recurse(ref i, body.Left), nullTest, WherePart.IsSql("NULL"));
                }
                if (IsNullValue(body.Left))
                {
                    return WherePart.Concat(Recurse(ref i, body.Right), nullTest, WherePart.IsSql("NULL"));
                }
            }

            // Operands of AND/OR must be predicates themselves, e.g. a bare bool column becomes "([col] = @n)".
            var operandsArePredicates =
                body.NodeType == ExpressionType.AndAlso || body.NodeType == ExpressionType.OrElse;
            return WherePart.Concat(
                Recurse(ref i, body.Left, operandsArePredicates),
                NodeTypeToString(body.NodeType),
                Recurse(ref i, body.Right, operandsArePredicates));
        }

        if (expression is ConstantExpression)
        {
            var value = ((ConstantExpression)expression).Value;

            // Integer literals are safe to inline. InvariantCulture keeps the digits and minus sign plain ASCII
            // whatever the current thread's culture is.
            if (value is int)
            {
                return WherePart.IsSql(((int)value).ToString(System.Globalization.CultureInfo.InvariantCulture));
            }

            return ToParameter(ref i, value, isUnary, prefix, postfix);
        }

        if (expression is MemberExpression)
        {
            var member = (MemberExpression)expression;
            var property = member.Member as PropertyInfo;

            // Only members reached through the lambda parameter are columns. Captured variables, properties of
            // captured objects and static members are evaluated now and sent as parameters.
            if (property != null && IsRootedInParameter(member))
            {
                var column = WherePart.IsSql("[" + _tableDef.GetColumnNameFor(property.Name) + "]");
                if (isUnary && member.Type == typeof(bool))
                {
                    return WherePart.Concat(column, "=", WherePart.IsParameter(i++, true));
                }
                return column;
            }

            if (property != null || member.Member is FieldInfo)
            {
                return ToParameter(ref i, GetValue(member), isUnary, prefix, postfix);
            }

            throw new NotSupportedException($"Expression does not refer to a property or field: {expression}");
        }

        if (expression is MethodCallExpression)
        {
            var methodCall = (MethodCallExpression)expression;

            // LIKE queries:
            if (methodCall.Method == typeof(string).GetMethod("Contains", new[] { typeof(string) }))
            {
                return LikeQuery(ref i, methodCall, "%", "%");
            }
            if (methodCall.Method == typeof(string).GetMethod("StartsWith", new[] { typeof(string) }))
            {
                return LikeQuery(ref i, methodCall, null, "%");
            }
            if (methodCall.Method == typeof(string).GetMethod("EndsWith", new[] { typeof(string) }))
            {
                return LikeQuery(ref i, methodCall, "%", null);
            }

            // IN queries:
            if (methodCall.Method.Name == "Contains")
            {
                return InList(ref i, methodCall);
            }

            throw new NotSupportedException("Unsupported method call: " + methodCall.Method.Name);
        }

        throw new NotSupportedException("Unsupported expression: " + expression.GetType().Name);
    }

    /// <summary>
    /// Builds "(column LIKE @n ESCAPE '\')" for string.Contains/StartsWith/EndsWith. The searched-for value has
    /// its wildcards escaped (see <see cref="ToParameter"/>), so it matches literally, as the .NET methods do.
    /// </summary>
    private WherePart LikeQuery(ref int i, MethodCallExpression methodCall, string prefix, string postfix)
    {
        var column = Recurse(ref i, methodCall.Object);
        var pattern = Recurse(ref i, methodCall.Arguments[0], prefix: prefix, postfix: postfix);
        return new WherePart
        {
            Parameters = column.Parameters.Union(pattern.Parameters).ToDictionary(kvp => kvp.Key, kvp => kvp.Value),
            Sql = "(" + column.Sql + " LIKE " + pattern.Sql + " ESCAPE '\\')"
        };
    }

    /// <summary>
    /// Translates <c>list.Contains(x.Col)</c> or <c>Enumerable.Contains(list, x.Col)</c> into "([col] IN (@n, ...))".
    /// </summary>
    private WherePart InList(ref int i, MethodCallExpression methodCall)
    {
        var isExtension = methodCall.Method.IsDefined(typeof(ExtensionAttribute));
        Expression collection;
        Expression property;
        if (isExtension && methodCall.Arguments.Count == 2)
        {
            collection = methodCall.Arguments[0];
            property = methodCall.Arguments[1];
        }
        else if (!isExtension && methodCall.Arguments.Count == 1)
        {
            collection = methodCall.Object;
            property = methodCall.Arguments[0];
        }
        else
        {
            throw new NotSupportedException("Unsupported method call: " + methodCall.Method.Name);
        }

        // Materialize once so the collection is enumerated a single time.
        var values = ((IEnumerable)GetValue(collection)).Cast<object>().ToList();

        // "col IN (NULL)" evaluates to UNKNOWN, so NOT(...) around it would also reject every row. An explicit
        // false predicate keeps !list.Contains(x.Col) true for an empty list.
        if (values.Count == 0)
        {
            return WherePart.IsSql("(1=0)");
        }

        return WherePart.Concat(Recurse(ref i, property), "IN", WherePart.IsCollection(ref i, values));
    }

    /// <summary>
    /// Wraps an evaluated value in a parameter. When a LIKE prefix/postfix is supplied, a string value has its
    /// wildcards escaped and is wrapped with them. A boolean that must stand alone as a predicate becomes "(@n = 1)".
    /// </summary>
    private static WherePart ToParameter(ref int i, object value, bool isUnary, string prefix, string postfix)
    {
        var text = value as string;
        if (text != null && (prefix != null || postfix != null))
        {
            value = prefix + EscapeLike(text) + postfix;
        }

        if (value is bool && isUnary)
        {
            return WherePart.Concat(WherePart.IsParameter(i++, value), "=", WherePart.IsSql("1"));
        }
        return WherePart.IsParameter(i++, value);
    }

    /// <summary>
    /// Escapes LIKE metacharacters with '\' (paired with the ESCAPE '\' clause emitted by <see cref="LikeQuery"/>).
    /// '[' is included because SQL Server treats it as the start of a character class.
    /// </summary>
    private static string EscapeLike(string value)
    {
        return value
            .Replace("\\", "\\\\")
            .Replace("%", "\\%")
            .Replace("_", "\\_")
            .Replace("[", "\\[");
    }

    /// <summary>
    /// True when <paramref name="expression"/> (looking through casts) is a null constant, or a parameter-free
    /// member access, such as a captured variable, that currently evaluates to null.
    /// </summary>
    private static bool IsNullValue(Expression expression)
    {
        while (expression.NodeType == ExpressionType.Convert || expression.NodeType == ExpressionType.ConvertChecked)
        {
            expression = ((UnaryExpression)expression).Operand;
        }

        var constant = expression as ConstantExpression;
        if (constant != null)
        {
            return constant.Value == null;
        }

        var member = expression as MemberExpression;
        if (member == null || IsRootedInParameter(member))
        {
            return false;
        }

        // Non-nullable value types can never be null; skip compiling a getter for them.
        if (member.Type.IsValueType && Nullable.GetUnderlyingType(member.Type) == null)
        {
            return false;
        }

        return GetValue(member) == null;
    }

    /// <summary>
    /// True when the member chain (looking through casts) starts at a lambda parameter, e.g. <c>x.Name</c>;
    /// false for captured variables, members of captured objects and static members.
    /// </summary>
    private static bool IsRootedInParameter(MemberExpression member)
    {
        var current = member.Expression;
        while (current != null)
        {
            if (current is MemberExpression)
            {
                current = ((MemberExpression)current).Expression;
            }
            else if (current.NodeType == ExpressionType.Convert ||
                     current.NodeType == ExpressionType.ConvertChecked ||
                     current.NodeType == ExpressionType.TypeAs)
            {
                current = ((UnaryExpression)current).Operand;
            }
            else
            {
                break;
            }
        }
        return current is ParameterExpression;
    }

    /// <summary>
    /// Formats a value as SQL text. Booleans become 1/0, or (1=1)/(1=0) as a stand-alone predicate; other values
    /// are delegated to the provider. Not used by the parameterized translation itself.
    /// </summary>
    public string ValueToString(object value, bool isUnary, bool quote)
    {
        if (value is bool)
        {
            if (isUnary)
                return (bool)value ? "(1=1)" : "(1=0)";
            return (bool)value ? "1" : "0";
        }
        return _provider.ValueToString(value, quote);
    }

    /// <summary>Evaluates an expression that does not depend on the lambda parameter.</summary>
    private static object GetValue(Expression member)
    {
        // source: http://stackoverflow.com/a/2616980/291955
        var objectMember = Expression.Convert(member, typeof(object));
        var getterLambda = Expression.Lambda<Func<object>>(objectMember);
        var getter = getterLambda.Compile();
        return getter();
    }

    /// <summary>Maps an expression node type to its SQL operator.</summary>
    private static string NodeTypeToString(ExpressionType nodeType)
    {
        switch (nodeType)
        {
            case ExpressionType.Add:
                return "+";
            case ExpressionType.And:
                return "&";
            case ExpressionType.AndAlso:
                return "AND";
            case ExpressionType.Divide:
                return "/";
            case ExpressionType.Equal:
                return "=";
            case ExpressionType.ExclusiveOr:
                return "^";
            case ExpressionType.GreaterThan:
                return ">";
            case ExpressionType.GreaterThanOrEqual:
                return ">=";
            case ExpressionType.LessThan:
                return "<";
            case ExpressionType.LessThanOrEqual:
                return "<=";
            case ExpressionType.Modulo:
                return "%";
            case ExpressionType.Multiply:
                return "*";
            case ExpressionType.Negate:
                return "-";
            case ExpressionType.Not:
                return "NOT";
            case ExpressionType.NotEqual:
                return "<>";
            case ExpressionType.Or:
                return "|";
            case ExpressionType.OrElse:
                return "OR";
            case ExpressionType.Subtract:
                return "-";
        }
        throw new NotSupportedException($"Unsupported node type: {nodeType}");
    }
}

/// <summary>
/// A fragment of SQL together with the parameter values it references. Leaves are created with
/// <see cref="IsSql"/>, <see cref="IsParameter"/> and <see cref="IsCollection"/>. They are combined with the
/// <c>Concat</c> overloads, which parenthesize their result, so fragments compose like the source expression tree.
/// </summary>
public class WherePart
{
    /// <summary>The SQL text of this fragment; parameters appear as <c>@name</c>.</summary>
    public string Sql { get; set; }

    /// <summary>Parameter values keyed by name (without the leading <c>@</c>).</summary>
    public Dictionary<string, object> Parameters { get; set; } = new Dictionary<string, object>();

    /// <summary>A literal SQL fragment that references no parameters.</summary>
    public static WherePart IsSql(string sql)
    {
        return new WherePart { Sql = sql };
    }

    /// <summary>A single parameter named after <paramref name="count"/>, rendered as <c>@count</c>.</summary>
    public static WherePart IsParameter(int count, object value)
    {
        var name = count.ToString(System.Globalization.CultureInfo.InvariantCulture);
        return new WherePart
        {
            Parameters = { { name, value } },
            Sql = "@" + name
        };
    }

    /// <summary>
    /// A parenthesized, comma-separated parameter list such as "(@3,@4)", one parameter per element of
    /// <paramref name="values"/>. Parameters are numbered from <paramref name="countStart"/>, which is advanced
    /// past the last number used. An empty sequence yields "(null)", because "()" is not valid SQL.
    /// Note that "col IN (null)" is UNKNOWN rather than false, so negating it does not select any rows either.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
    public static WherePart IsCollection(ref int countStart, IEnumerable values)
    {
        if (values == null)
        {
            throw new ArgumentNullException(nameof(values));
        }

        var parameters = new Dictionary<string, object>();
        var placeholders = new List<string>();
        foreach (var value in values)
        {
            var name = countStart.ToString(System.Globalization.CultureInfo.InvariantCulture);
            parameters.Add(name, value);
            placeholders.Add("@" + name);
            countStart++;
        }

        return new WherePart
        {
            Parameters = parameters,
            Sql = placeholders.Count == 0 ? "(null)" : "(" + string.Join(",", placeholders) + ")"
        };
    }

    /// <summary>Applies a prefix operator: "(operator operand)".</summary>
    public static WherePart Concat(string @operator, WherePart operand)
    {
        return new WherePart
        {
            // Copy, so that the result and the operand do not share (and later mutate) a single dictionary.
            Parameters = new Dictionary<string, object>(operand.Parameters),
            Sql = $"({@operator} {operand.Sql})"
        };
    }

    /// <summary>
    /// Applies an infix operator: "(left operator right)". The parameters of both sides are merged; a name bound
    /// to two different values makes <c>ToDictionary</c> throw.
    /// </summary>
    public static WherePart Concat(WherePart left, string @operator, WherePart right)
    {
        return new WherePart
        {
            Parameters = left.Parameters.Union(right.Parameters).ToDictionary(kvp => kvp.Key, kvp => kvp.Value),
            Sql = $"({left.Sql} {@operator} {right.Sql})"
        };
    }
}

Anda mungkin juga menyukai