diff --git a/src/NServiceKit.OrmLite/Expressions/SqlExpressionVisitor.cs b/src/NServiceKit.OrmLite/Expressions/SqlExpressionVisitor.cs index 30b9c5d..2628d32 100644 --- a/src/NServiceKit.OrmLite/Expressions/SqlExpressionVisitor.cs +++ b/src/NServiceKit.OrmLite/Expressions/SqlExpressionVisitor.cs @@ -18,7 +18,7 @@ public abstract class SqlExpressionVisitor private Expression> underlyingExpression; /// The order by properties. - private List orderByProperties = new List(); + private List orderByProperties = new List(); /// The select expression. private string selectExpression = string.Empty; @@ -54,7 +54,7 @@ public abstract class SqlExpressionVisitor /// Gets or sets a value indicating whether the prefix field with table name. /// /// true if prefix field with table name, false if not. - public bool PrefixFieldWithTableName {get;set;} + public bool PrefixFieldWithTableName { get; set; } /// /// Gets or sets a value indicating whether the where statement without where string. @@ -300,7 +300,7 @@ public virtual SqlExpressionVisitor OrderBy(Expression> k sep = string.Empty; useFieldName = true; orderByProperties.Clear(); - var property = Visit(keySelector).ToString(); + var property = Visit(keySelector).ToString(); orderByProperties.Add(property + " ASC"); BuildOrderByClauseInternal(); return this; @@ -310,7 +310,7 @@ public virtual SqlExpressionVisitor OrderBy(Expression> k /// Type of the key. /// The key selector. /// A SqlExpressionVisitor<T> - public virtual SqlExpressionVisitor ThenBy (Expression> keySelector) + public virtual SqlExpressionVisitor ThenBy(Expression> keySelector) { sep = string.Empty; useFieldName = true; @@ -355,7 +355,7 @@ private void BuildOrderByClauseInternal() if (orderByProperties.Count > 0) { orderBy = "ORDER BY "; - foreach(var prop in orderByProperties) + foreach (var prop in orderByProperties) { orderBy += prop + ","; } @@ -773,7 +773,7 @@ protected virtual object VisitBinary(BinaryExpression b) if (operand == "AND" || operand == "OR") { var m = b.Left as MemberExpression; - if (m != null && m.Expression != null + if (m != null && m.Expression != null && m.Expression.NodeType == ExpressionType.Parameter) left = new PartialSqlString(string.Format("{0}={1}", VisitMemberAccess(m), GetQuotedTrueValue())); else @@ -792,8 +792,8 @@ protected virtual object VisitBinary(BinaryExpression b) return new PartialSqlString(OrmLiteConfig.DialectProvider.GetQuotedValue(result, result.GetType())); } - if(left as PartialSqlString == null) - left = ((bool) left) ? GetTrueExpression() : GetFalseExpression(); + if (left as PartialSqlString == null) + left = ((bool)left) ? GetTrueExpression() : GetFalseExpression(); if (right as PartialSqlString == null) right = ((bool)right) ? GetTrueExpression() : GetFalseExpression(); } @@ -835,7 +835,7 @@ protected virtual object VisitBinary(BinaryExpression b) left = OrmLiteConfig.DialectProvider.GetQuotedValue(left, left != null ? left.GetType() : null); else if (right as PartialSqlString == null) right = OrmLiteConfig.DialectProvider.GetQuotedValue(right, right != null ? right.GetType() : null); - + } if (operand == "=" && right.ToString().Equals("null", StringComparison.InvariantCultureIgnoreCase)) operand = "is"; @@ -847,7 +847,7 @@ protected virtual object VisitBinary(BinaryExpression b) case "COALESCE": return new PartialSqlString(string.Format("{0}({1},{2})", operand, left, right)); default: - return new PartialSqlString("(" + left + sep + operand + sep + right +")"); + return new PartialSqlString("(" + left + sep + operand + sep + right + ")"); } } @@ -864,7 +864,7 @@ protected virtual object VisitMemberAccess(MemberExpression m) if (propertyInfo.PropertyType.IsEnum) return new EnumMemberAccess((PrefixFieldWithTableName ? OrmLiteConfig.DialectProvider.GetQuotedTableName(modelDef.ModelName) + "." : "") + GetQuotedColumnName(m.Member.Name), propertyInfo.PropertyType); - return new PartialSqlString((PrefixFieldWithTableName ? OrmLiteConfig.DialectProvider.GetQuotedTableName(modelDef.ModelName)+"." : "") + GetQuotedColumnName(m.Member.Name)); + return new PartialSqlString((PrefixFieldWithTableName ? OrmLiteConfig.DialectProvider.GetQuotedTableName(modelDef.ModelName) + "." : "") + GetQuotedColumnName(m.Member.Name)); } var member = Expression.Convert(m, typeof(object)); @@ -961,7 +961,7 @@ protected virtual object VisitUnary(UnaryExpression u) return !((bool)o); if (IsFieldName(o)) - o = o + "=" + GetQuotedTrueValue(); + o = o + "=" + GetQuotedTrueValue(); return new PartialSqlString("NOT (" + o + ")"); case ExpressionType.Convert: @@ -983,8 +983,8 @@ private bool IsColumnAccess(MethodCallExpression m) return IsColumnAccess(m.Object as MethodCallExpression); var exp = m.Object as MemberExpression; - return exp != null - && exp.Expression != null + return exp != null + && exp.Expression != null && exp.Expression.Type == typeof(T) && exp.Expression.NodeType == ExpressionType.Parameter; } @@ -997,7 +997,7 @@ protected virtual object VisitMethodCall(MethodCallExpression m) if (m.Method.DeclaringType == typeof(Sql)) return VisitSqlMethodCall(m); - if (IsArrayMethod(m)) + if (IsArrayMethod(m)) return VisitArrayMethodCall(m); if (IsColumnAccess(m)) @@ -1159,7 +1159,7 @@ protected object GetFalseExpression() /// The quoted true value. protected static object GetQuotedTrueValue() { - return new PartialSqlString(OrmLiteConfig.DialectProvider.GetQuotedValue(true, typeof (bool))); + return new PartialSqlString(OrmLiteConfig.DialectProvider.GetQuotedValue(true, typeof(bool))); } /// Gets quoted false value. @@ -1195,7 +1195,7 @@ public IList GetAllFields() /// A string. protected virtual string ApplyPaging(string sql) { - sql = sql + (string.IsNullOrEmpty(LimitExpression) ? "" :"\n" + LimitExpression); + sql = sql + (string.IsNullOrEmpty(LimitExpression) ? "" : "\n" + LimitExpression); return sql; } @@ -1390,6 +1390,19 @@ protected virtual object VisitColumnAccessMethod(MethodCallExpression m) } return new PartialSqlString(statement); } + + public string ToInsertWhereNotExistsStatement(object obj) + { + var sql = new StringBuilder(); + var fieldDefs = modelDef.FieldDefinitions.Where(t => t.AutoIncrement == false && t.IsComputed == false).ToList(); + sql.AppendFormat("Insert Into {0} ({1})", OrmLiteConfig.DialectProvider.GetQuotedTableName(modelDef), string.Join(",", fieldDefs.Select(t => OrmLiteConfig.DialectProvider.GetQuotedColumnName(t.FieldName)).ToArray())); + sql.Append("\n"); + sql.AppendFormat("Select {0}", string.Join(",", fieldDefs.Select(t => t.GetQuotedValue(obj)).ToArray())); + sql.Append("\n"); + BuildSelectExpression("1", false); + sql.AppendFormat(" where not exists({0} {1})", selectExpression, WhereExpression); + return sql.ToString(); + } } /// A partial SQL string. @@ -1431,7 +1444,8 @@ public class EnumMemberAccess : PartialSqlString /// illegal values. /// The text. /// The type of the enum. - public EnumMemberAccess(string text, Type enumType) : base(text) + public EnumMemberAccess(string text, Type enumType) + : base(text) { if (!enumType.IsEnum) throw new ArgumentException("Type not valid", "enumType"); diff --git a/src/NServiceKit.OrmLite/OrmLiteWriteConnectionExtensions.cs b/src/NServiceKit.OrmLite/OrmLiteWriteConnectionExtensions.cs index 9db95af..18b3749 100644 --- a/src/NServiceKit.OrmLite/OrmLiteWriteConnectionExtensions.cs +++ b/src/NServiceKit.OrmLite/OrmLiteWriteConnectionExtensions.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Data; using System.Linq; +using System.Linq.Expressions; namespace NServiceKit.OrmLite { @@ -297,6 +298,18 @@ public static void Insert(this IDbConnection dbConn, params T[] objs) dbConn.Exec(dbCmd => dbCmd.Insert(objs)); } + + /// An IDbConnection extension method that inserts where not exists. + /// Generic type parameter. + /// The dbConn to act on. + /// The object. + /// predicate to create where clause + public static void InsertWhereNotExists(this IDbConnection dbConn, T obj, Expression> wherePredicate) + where T : new() + { + dbConn.Exec(dbCmd => dbCmd.InsertWhereNotExists(obj,wherePredicate)); + } + /// An IDbConnection extension method that inserts all. /// Generic type parameter. /// The dbConn to act on. diff --git a/src/NServiceKit.OrmLite/OrmLiteWriteExtensions.cs b/src/NServiceKit.OrmLite/OrmLiteWriteExtensions.cs index 7947275..1172a7d 100644 --- a/src/NServiceKit.OrmLite/OrmLiteWriteExtensions.cs +++ b/src/NServiceKit.OrmLite/OrmLiteWriteExtensions.cs @@ -15,6 +15,7 @@ using System.Data; using System.Diagnostics; using System.Linq; +using System.Linq.Expressions; using System.Text.RegularExpressions; using NServiceKit.Common.Utils; using NServiceKit.Logging; @@ -84,38 +85,38 @@ internal static void CreateTable(this IDbCommand dbCmd, bool overwrite, Type mod { var modelDef = modelType.GetModelDefinition(); - var dialectProvider = OrmLiteConfig.DialectProvider; - var tableName = dialectProvider.NamingStrategy.GetTableName(modelDef.ModelName); - var tableExists = dialectProvider.DoesTableExist(dbCmd, tableName); - if (overwrite && tableExists) + var dialectProvider = OrmLiteConfig.DialectProvider; + var tableName = dialectProvider.NamingStrategy.GetTableName(modelDef.ModelName); + var tableExists = dialectProvider.DoesTableExist(dbCmd, tableName); + if (overwrite && tableExists) { DropTable(dbCmd, modelDef); - tableExists = false; + tableExists = false; } try { - if (!tableExists) - { - ExecuteSql(dbCmd, dialectProvider.ToCreateTableStatement(modelType)); - - var sqlIndexes = dialectProvider.ToCreateIndexStatements(modelType); - foreach (var sqlIndex in sqlIndexes) - { - try - { - dbCmd.ExecuteSql(sqlIndex); - } - catch (Exception exIndex) - { - if (IgnoreAlreadyExistsError(exIndex)) - { - Log.DebugFormat("Ignoring existing index '{0}': {1}", sqlIndex, exIndex.Message); - continue; - } - throw; - } - } + if (!tableExists) + { + ExecuteSql(dbCmd, dialectProvider.ToCreateTableStatement(modelType)); + + var sqlIndexes = dialectProvider.ToCreateIndexStatements(modelType); + foreach (var sqlIndex in sqlIndexes) + { + try + { + dbCmd.ExecuteSql(sqlIndex); + } + catch (Exception exIndex) + { + if (IgnoreAlreadyExistsError(exIndex)) + { + Log.DebugFormat("Ignoring existing index '{0}': {1}", sqlIndex, exIndex.Message); + continue; + } + throw; + } + } var sequenceList = dialectProvider.SequenceList(modelType); if (sequenceList.Count > 0) @@ -150,7 +151,7 @@ internal static void CreateTable(this IDbCommand dbCmd, bool overwrite, Type mod } } } - } + } } catch (Exception ex) { @@ -246,14 +247,14 @@ private static bool IgnoreAlreadyExistsError(Exception ex) //ignore Sqlite table already exists error const string sqliteAlreadyExistsError = "already exists"; const string sqlServerAlreadyExistsError = "There is already an object named"; - return ex.Message.Contains(sqliteAlreadyExistsError) - || ex.Message.Contains(sqlServerAlreadyExistsError) ; + return ex.Message.Contains(sqliteAlreadyExistsError) + || ex.Message.Contains(sqlServerAlreadyExistsError); } /// DEFINE GENERATOR failed. /// The ex. /// true if it succeeds, false if it fails. - private static bool IgnoreAlreadyExistsGeneratorError(Exception ex) + private static bool IgnoreAlreadyExistsGeneratorError(Exception ex) { const string fbError = "attempt to store duplicate value"; return ex.Message.Contains(fbError); @@ -288,7 +289,7 @@ public static int GetColumnIndex(this IDataReader dataReader, string fieldName) } /// The not found. - private const int NotFound = -1; + private const int NotFound = -1; /// A T extension method that populate with SQL reader. /// Generic type parameter. @@ -299,10 +300,10 @@ public static int GetColumnIndex(this IDataReader dataReader, string fieldName) /// A T. public static T PopulateWithSqlReader(this T objWithProperties, IDataReader dataReader, FieldDefinition[] fieldDefs, Dictionary indexCache) { - try - { - foreach (var fieldDef in fieldDefs) - { + try + { + foreach (var fieldDef in fieldDefs) + { int index; if (indexCache != null) { @@ -325,17 +326,17 @@ public static T PopulateWithSqlReader(this T objWithProperties, IDataReader d index = TryGuessColumnIndex(fieldDef.FieldName, dataReader); } } - - if (index == NotFound) continue; - var value = dataReader.GetValue(index); - fieldDef.SetValue(objWithProperties, value); - } - } - catch (Exception ex) - { - Log.Error(ex); - } - return objWithProperties; + + if (index == NotFound) continue; + var value = dataReader.GetValue(index); + fieldDef.SetValue(objWithProperties, value); + } + } + catch (Exception ex) + { + Log.Error(ex); + } + return objWithProperties; } /// The allowed property characters regular expression. @@ -413,24 +414,24 @@ private static int TryGuessColumnIndex(string fieldName, IDataReader dataReader) /// The dbCmd to act on. /// The objects. internal static void Update(this IDbCommand dbCmd, params T[] objs) - { - foreach (var obj in objs) - { - dbCmd.ExecuteSql(OrmLiteConfig.DialectProvider.ToUpdateRowStatement(obj)); - } - } + { + foreach (var obj in objs) + { + dbCmd.ExecuteSql(OrmLiteConfig.DialectProvider.ToUpdateRowStatement(obj)); + } + } /// An IDbCommand extension method that updates all. /// Generic type parameter. /// The dbCmd to act on. /// The objects. internal static void UpdateAll(this IDbCommand dbCmd, IEnumerable objs) - { - foreach (var obj in objs) - { - dbCmd.ExecuteSql(OrmLiteConfig.DialectProvider.ToUpdateRowStatement(obj)); - } - } + { + foreach (var obj in objs) + { + dbCmd.ExecuteSql(OrmLiteConfig.DialectProvider.ToUpdateRowStatement(obj)); + } + } /// An IDbConnection extension method that creates update statement. /// Generic type parameter. @@ -447,24 +448,24 @@ internal static IDbCommand CreateUpdateStatement(this IDbConnection connectio /// The dbCmd to act on. /// The objects. internal static void Delete(this IDbCommand dbCmd, params T[] objs) - { - foreach (var obj in objs) - { - dbCmd.ExecuteSql(OrmLiteConfig.DialectProvider.ToDeleteRowStatement(obj)); - } - } + { + foreach (var obj in objs) + { + dbCmd.ExecuteSql(OrmLiteConfig.DialectProvider.ToDeleteRowStatement(obj)); + } + } /// An IDbCommand extension method that deletes all described by dbCmd. /// Generic type parameter. /// The dbCmd to act on. /// The objects. internal static void DeleteAll(this IDbCommand dbCmd, IEnumerable objs) - { - foreach (var obj in objs) - { - dbCmd.ExecuteSql(OrmLiteConfig.DialectProvider.ToDeleteRowStatement(obj)); - } - } + { + foreach (var obj in objs) + { + dbCmd.ExecuteSql(OrmLiteConfig.DialectProvider.ToDeleteRowStatement(obj)); + } + } /// An IDbCommand extension method that deletes the by identifier. /// Generic type parameter. @@ -510,7 +511,7 @@ internal static void DeleteByIds(this IDbCommand dbCmd, IEnumerable idValues) internal static void DeleteByIdParam(this IDbCommand dbCmd, object id) { var modelDef = ModelDefinition.Definition; - var idParamString = OrmLiteConfig.DialectProvider.ParamString+"0"; + var idParamString = OrmLiteConfig.DialectProvider.ParamString + "0"; var sql = string.Format("DELETE FROM {0} WHERE {1} = {2}", OrmLiteConfig.DialectProvider.GetQuotedTableName(modelDef), @@ -521,7 +522,7 @@ internal static void DeleteByIdParam(this IDbCommand dbCmd, object id) idParam.ParameterName = idParamString; idParam.Value = id; dbCmd.Parameters.Add(idParam); - + dbCmd.ExecuteSql(sql); } @@ -538,8 +539,8 @@ internal static void DeleteAll(this IDbCommand dbCmd) /// Type of the table. internal static void DeleteAll(this IDbCommand dbCmd, Type tableType) { - dbCmd.ExecuteSql(OrmLiteConfig.DialectProvider.ToDeleteStatement(tableType, null)); - } + dbCmd.ExecuteSql(OrmLiteConfig.DialectProvider.ToDeleteStatement(tableType, null)); + } /// An IDbCommand extension method that deletes this object. /// Generic type parameter. @@ -584,24 +585,36 @@ internal static void Save(this IDbCommand dbCmd, T obj) /// The dbCmd to act on. /// The objects. internal static void Insert(this IDbCommand dbCmd, params T[] objs) - { - foreach (var obj in objs) - { - dbCmd.ExecuteSql(OrmLiteConfig.DialectProvider.ToInsertRowStatement(dbCmd, obj)); - } - } + { + foreach (var obj in objs) + { + dbCmd.ExecuteSql(OrmLiteConfig.DialectProvider.ToInsertRowStatement(dbCmd, obj)); + } + } + + /// An IDbCommand extension method that inserts where not exists. + /// Generic type parameter. + /// The dbCmd to act on. + /// The object. + /// the predicate to generate where clause + internal static void InsertWhereNotExists(this IDbCommand dbCmd, T obj, Expression> predicate) + { + var ev = OrmLiteConfig.DialectProvider.ExpressionVisitor(); + var sql = ev.Where(predicate).ToInsertWhereNotExistsStatement(obj); + dbCmd.ExecuteSql(sql); + } /// An IDbCommand extension method that inserts all. /// Generic type parameter. /// The dbCmd to act on. /// The objects. internal static void InsertAll(this IDbCommand dbCmd, IEnumerable objs) - { - foreach (var obj in objs) - { - dbCmd.ExecuteSql(OrmLiteConfig.DialectProvider.ToInsertRowStatement(dbCmd, obj)); - } - } + { + foreach (var obj in objs) + { + dbCmd.ExecuteSql(OrmLiteConfig.DialectProvider.ToInsertRowStatement(dbCmd, obj)); + } + } /// An IDbConnection extension method that creates insert statement. /// Generic type parameter. @@ -691,10 +704,10 @@ internal static IDbTransaction BeginTransaction(this IDbCommand dbCmd, Isolation /// The object. internal static void ExecuteProcedure(this IDbCommand dbCommand, T obj) { - string sql = OrmLiteConfig.DialectProvider.ToExecuteProcedureStatement(obj); - dbCommand.CommandType= CommandType.StoredProcedure; - dbCommand.ExecuteSql(sql); - } - + string sql = OrmLiteConfig.DialectProvider.ToExecuteProcedureStatement(obj); + dbCommand.CommandType = CommandType.StoredProcedure; + dbCommand.ExecuteSql(sql); + } + } }