[C#]Dapperに似た軽量ORマッパーを自作した話

プロダクトの大事な基幹であるDBアクセス処理。データをシームレスに扱うためにはORマッパーは欠かせない存在となってきています。

今回、ORマッパーにDapperを使おうと検討したところ、思わぬ壁に当たったため自作してみました。

要件

1.RubyやLaravelのようにリッチなものは必要ない。
2.Modelとテーブルのカラム名が全く違っても使える(!)
3.とりあえずシンプルなCRUDが達成できればいい。

この(2)が曲者で、 Dapperでは直感的に実装することができなくなってしまいます。スネークケースをキャメルケースに直すくらいならできるみたいですが。実案件ではどうしてもここが障壁になることがあると思います。

個人的にはこれくらい薄いラッパーが使いやすくて好きです。特に帳票やデータ解析が絡むとSQLを書く前提の方が管理しやすかったり。シンプルなCRUDだけで事足りるならSQLを書くことはむしろ保守性を下げる要因とみなされるでしょうけど。

nuGetの追加

必要なパッケージは「Microsoft.Data.SqlClient」と「System.Reflection」です。それぞれインストールしてください。

マッパークラスの作成

早速コード。

using Microsoft.Data.SqlClient;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

namespace ネームスペース
{
    public static class Mapper
    {

        public static List<Model> Query<Model>(this SqlConnection connection, SqlTransaction? transaction, string sql,  params (string paramName,object value)[] parameters) where Model : class, new()
        {
            var param = new List<SqlParameter>();
            foreach(var item in parameters)
            {
                param.Add(new SqlParameter( item.paramName, item.value));
            }
            return Query<Model>(connection, transaction, sql, param);
        }
        public static List<Model> Query<Model>(this SqlConnection connection, SqlTransaction? transaction, string sql, List<SqlParameter> parameters) where Model : class, new()
        {
            List<Model> models = [];
            var modelType = typeof(Model);
            var propertyMaps = GetPropertyAndMapColumnNames(modelType);

            using (SqlCommand command = new SqlCommand(sql, connection))
            {   
                if(transaction != null)
                    command.Transaction = transaction;

                if (parameters != null)
                {
                    foreach (var parameter in parameters)
                    {
                        command.Parameters.Add(parameter);
                    }
                }
                using (SqlDataReader reader = command.ExecuteReader())
                {
                    while (reader.Read())
                    {
                        var model = new Model();
                        foreach (var pair in propertyMaps)
                        {
                            pair.Key.SetValue(model, reader[pair.Value.Name]);
                        }
                        models.Add(model);
                    }
                }
            }
            return models;
        }

        public static List<Model> QueryWhere<Model>(this SqlConnection connection, SqlTransaction? transaction, string sqlWhere,  List<SqlParameter> parameters) where Model : class, new()
        {
            List<Model> models = [];
            var modelType = typeof(Model);
            var propertyMaps = GetPropertyAndMapColumnNames(modelType);

            using (SqlCommand command = new SqlCommand("select * from " + GetTableName(modelType) + " " + sqlWhere, connection))
            {
                if (transaction != null)
                    command.Transaction = transaction;

                if (parameters != null)
                {
                    foreach (var parameter in parameters)
                    {
                        command.Parameters.Add(parameter);
                    }
                }
                using (SqlDataReader reader = command.ExecuteReader())
                {
                    while (reader.Read())
                    {
                        var model = new Model();
                        foreach (var pair in propertyMaps)
                        {
                            pair.Key.SetValue(model, reader[pair.Value.Name]);
                        }
                        models.Add(model);
                    }
                }
            }
            return models;
        }

        public static List<SqlParameter> CreateParams(params (string name, object value)[] parameters)
        {
            List<SqlParameter> sqlParameters = new List<SqlParameter>();
            foreach (var param in parameters)
            {
                sqlParameters.Add(new SqlParameter() { ParameterName = param.name, Value = param.value });
            }
            return sqlParameters;
        }

        private static Dictionary<PropertyInfo, MapColumn> GetPropertyAndMapColumnNames(Type type)
        {
            Dictionary<PropertyInfo, MapColumn> propertyNameAndColumnNames = new Dictionary<PropertyInfo, MapColumn>();
            foreach (var property in type.GetProperties())
            {
                if (property.IsDefined(typeof(MapColumnAttribute), true))
                {
                    bool isPrimaryKey = false;
                    MapColumnAttribute attribute = (MapColumnAttribute)property.GetCustomAttribute(typeof(MapColumnAttribute), true)!;
                    if (property.IsDefined(typeof(MapPrimaryKeyAttribute), true))
                    {
                        isPrimaryKey = true;
                    }
                    propertyNameAndColumnNames.Add(property, new MapColumn(attribute.ColumnName, isPrimaryKey));
                }
            }
            return propertyNameAndColumnNames;
        }

        private static string GetTableName(Type type)
        {
            if (type.IsDefined(typeof(MapTableAttribute), true))
            {
                MapTableAttribute attribute = (MapTableAttribute)type.GetCustomAttribute(typeof(MapTableAttribute), true)!;
                return attribute.TableName;
            }
            return "";
        }

        public static int Insert<Model>(this SqlConnection connection, SqlTransaction? transaction, Model model ) where Model : class, new()
        {
            var modelType = typeof(Model);
            var propertyMaps = GetPropertyAndMapColumnNames(modelType);
            var tableName = GetTableName(modelType);
            var parameters = new List<SqlParameter>();

            string sql = "";
            sql += "insert into " + tableName + " values (";
            string sqlValues = "";
            foreach (var pair in propertyMaps)
            {
                if (!string.IsNullOrEmpty(sqlValues))
                    sqlValues += ", ";
                var value = pair.Key.GetValue(model);
                sqlValues += "@" + pair.Value.Name;
                parameters.Add(new SqlParameter("@" + pair.Value.Name, value));
            }
            sql += sqlValues + ")";

            int resultCount = 0;
            using (SqlCommand command = new SqlCommand(sql, connection))
            {
                if (transaction != null)
                    command.Transaction = transaction;

                foreach (var parameter in parameters)
                {
                    command.Parameters.Add(parameter);
                }
                resultCount = command.ExecuteNonQuery();
            }
            return resultCount;
        }

        public static int Update<Model>(this SqlConnection connection, SqlTransaction? transaction, Model model ) where Model : class, new()
        {
            var modelType = typeof(Model);
            var propertyMaps = GetPropertyAndMapColumnNames(modelType);
            var tableName = GetTableName(modelType);
            var parameters = new List<SqlParameter>();

            string sql = "";
            var sqlWhere = "";
            sql += "update " + tableName + " set ";
            string sqlValues = "";
            foreach (var pair in propertyMaps)
            {
                var value = pair.Key.GetValue(model);
                if (pair.Value.IsPrimaryKey)
                {
                    if (!string.IsNullOrEmpty(sqlWhere))
                        sqlWhere += " and ";

                    sqlWhere += pair.Value.Name + " = @" + pair.Value.Name;
                }
                else
                {
                    if (!string.IsNullOrEmpty(sqlValues))
                        sqlValues += ", ";

                    sqlValues += pair.Value.Name + " = @" + pair.Value.Name;
                }
                parameters.Add(new SqlParameter("@" + pair.Value.Name, value));
            }
            if (string.IsNullOrEmpty(sqlWhere)) throw new InvalidOperationException("not set Primarykey at Model class.");
            sql += sqlValues + " where " + sqlWhere;

            int resultCount = 0;
            using (SqlCommand command = new SqlCommand(sql, connection))
            {
                if (transaction != null)
                    command.Transaction = transaction;

                foreach (var parameter in parameters)
                {
                    command.Parameters.Add(parameter);
                }
                resultCount = command.ExecuteNonQuery();
            }
            return resultCount;
        }
        public static int Delete<Model>(this SqlConnection connection, SqlTransaction? transaction, Model model) where Model : class, new()
        {
            var modelType = typeof(Model);
            var propertyMaps = GetPropertyAndMapColumnNames(modelType);
            var tableName = GetTableName(modelType);
            var parameters = new List<SqlParameter>();

            string sql = "";
            var sqlWhere = "";
            sql += "delete from " + tableName;
            foreach (var pair in propertyMaps)
            {
                if (pair.Value.IsPrimaryKey)
                {
                    var value = pair.Key.GetValue(model);
                    if (!string.IsNullOrEmpty(sqlWhere))
                        sqlWhere += " and ";

                    sqlWhere += pair.Value.Name + " = @" + pair.Value.Name;
                    parameters.Add(new SqlParameter("@" + pair.Value.Name, value));
                }
            }
            if (string.IsNullOrEmpty(sqlWhere)) throw new InvalidOperationException("not set Primarykey at Model class.");
            sql += " where " + sqlWhere;

            int resultCount = 0;
            using (SqlCommand command = new SqlCommand(sql, connection))
            {
                if (transaction != null)
                    command.Transaction = transaction;

                foreach (var parameter in parameters)
                {
                    command.Parameters.Add(parameter);
                }
                resultCount = command.ExecuteNonQuery();
            }
            return resultCount;
        }

        public static int Delete<Model>(this SqlConnection connection, SqlTransaction? transaction, string whereBlock,  params SqlParameter[] parameters)
        {
            return Delete<Model>(connection, transaction, whereBlock,  parameters.ToList());
        }

        public static int Delete<Model>(this SqlConnection connection, SqlTransaction? transaction, string whereBlock, List<SqlParameter> parameters)
        {
            var modelType = typeof(Model);
            var tableName = GetTableName(modelType);

            if (string.IsNullOrEmpty(whereBlock))
                throw new InvalidOperationException("not set [where]block.");
            if (parameters.Count < 1)
                throw new InvalidOperationException("not set params");

            string sql = "delete from " + tableName + " where " + whereBlock;

            int resultCount = 0;
            using (SqlCommand command = new SqlCommand(sql, connection))
            {
                if (transaction != null)
                    command.Transaction = transaction;

                foreach (var parameter in parameters)
                {
                    command.Parameters.Add(parameter);
                }
                resultCount = command.ExecuteNonQuery();
            }
            return resultCount;
        }
    }

    public class MapColumn
    {
        public string Name { get; set; } = "";
        public bool IsPrimaryKey { get; set; } = false;

        public MapColumn(string name, bool isPrimaryKey = false)
        {
            Name = name;
            IsPrimaryKey = isPrimaryKey;
        }
    }

    public class MapTableAttribute : Attribute
    {
        public string TableName { get; } = "";
        public MapTableAttribute(string name)
        {
            TableName = name;
        }
    }

    public class MapColumnAttribute : Attribute
    {
        public string ColumnName { get; } = "";
        public MapColumnAttribute(string name)
        {
            ColumnName = name;
        }
    }

    public class MapPrimaryKeyAttribute : Attribute
    {
        public bool IsSet { get; } = true;
        public MapPrimaryKeyAttribute()
        {
        }
    }
}

使い方

モデルには属性でテーブル名とカラム名を指定。

	[MapTable("Test")]
	public class Test
	{
		[MapColumn("id")]
		public int TestId { get; set; }

		[MapColumn("name")]
		public string TestName { get; set; }

		public int NoColumn { get; set; } = 111;
	}

テーブルと結びつけたい場合はMapColumnで指定。結びつけなくて良いものは未指定。
MapTableでテーブル名を指定できるようにしています。

DBへの処理部分(トランザクション無しver)はこちら。パラメータ有りと無しそれぞれ。

			using (SqlConnection connection = new SqlConnection(connectionString))
			{
				connection.Open();
				var items = connection.Query<Test>(null,"select * from test");
				dataGridView1.DataSource = items;
				connection.Close();
			}


//パラメータあり
			using (SqlConnection connection = new SqlConnection(connectionString))
			{
				connection.Open();

				var items = connection.Query<Test>(null,"select * from test where id = @id",Mapper.CreateParams(("@id","1")));
				dataGridView1.DataSource = items;
				connection.Close();
			}

パラメータはList<SqlParameter>型でも渡せるので、処理の中で柔軟にパラメータを追加していくことも可能です。

トランザクション有りverは

            using (SqlConnection connection = new SqlConnection(connectionString))
            {
                connection.Open();
                var transaction = connection.BeginTransaction();
                try
                {
                    var user = new user();
                    user.Name = name;
                    user.Introduce = introduce;
                    user.UpdatedAt = DateTime.Now;

                    connection.Update<User>(transaction, user);

                    transaction.Commit();
                }
                catch (Exception)
                {
                    transaction.Rollback();
                    throw;
                }
                finally
                {
                    connection.Close();
                }
            }

こんな感じ。
トランザクションは明示的にnullなのか(管理しないのか)を把握できるようにしてます。

速度

今回5万件のデータをDataGridViewに表示するまででテストしてみましたが、自作は自由度がある分Dapperの方が早かったです。しっかりと平均値を出したわけではありませんが、大体Dapperだと70%くらいになります。私の関わるプロダクトであれば、これくらいの差なら柔軟に列名とマッピングできるメリットが上回ります。