Files
DbMigrate/SqlDatabase.cs

229 lines
10 KiB
C#

using Dapper;
using System;
using System.Collections.Generic;
using System.Data.SQLite;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
namespace DbMigrate {
    /// <summary>
    /// In-memory model of a SQLite schema: a list of <see cref="SqlTable"/> objects, each carrying
    /// its indexes and triggers. The model can be loaded from a SQL script
    /// (<see cref="LoadSql"/>, <see cref="LoadSqlFromFile"/>) or read from an existing database
    /// (<see cref="BuildSql"/>).
    /// </summary>
    public class SqlDatabase {
        /// <summary>Kind of CREATE statement the script parser is currently accumulating.</summary>
        private enum ElementKind { None, Table, Index, Trigger }

        /// <summary>An index or trigger read from a script, waiting to be attached to its owning table.</summary>
        private sealed class SchemaObject {
            public SchemaObject(string tableName, string name, string sql) {
                TableName = tableName;
                Name = name;
                Sql = sql;
            }
            public string TableName { get; }
            public string Name { get; }
            public string Sql { get; }
        }

        // SQL keywords are case-insensitive, so every pattern below is matched with IgnoreCase.

        // Start of CREATE [TEMP] TABLE.
        private static readonly Regex TableStartRegex = new Regex(
            @"^CREATE\s+(?:TEMP(?:ORARY)?\s+)?TABLE\b", RegexOptions.IgnoreCase);

        // Start of CREATE [UNIQUE] INDEX.
        private static readonly Regex IndexStartRegex = new Regex(
            @"^CREATE\s+(?:UNIQUE\s+)?INDEX\b", RegexOptions.IgnoreCase);

        // Start of CREATE [TEMP] TRIGGER.
        private static readonly Regex TriggerStartRegex = new Regex(
            @"^CREATE\s+(?:TEMP(?:ORARY)?\s+)?TRIGGER\b", RegexOptions.IgnoreCase);

        // End of a CREATE TABLE: ")" optionally followed by table options (WITHOUT ROWID, STRICT), then ";".
        private static readonly Regex TableEndRegex = new Regex(
            @"\)(?:\s*(?:WITHOUT\s+ROWID|STRICT)\s*,?)*\s*;$", RegexOptions.IgnoreCase);

        // End of a CREATE TRIGGER. Body statements also end in ";" (often ");"), so only "END;" may close it.
        private static readonly Regex TriggerEndRegex = new Regex(
            @"\bEND\s*;$", RegexOptions.IgnoreCase);

        // Index name and owning table of a CREATE [UNIQUE] INDEX [IF NOT EXISTS] [schema.]name ON table.
        // Identifiers may be bare or quoted with "", [] or ``.
        private static readonly Regex IndexHeaderRegex = new Regex(
            @"^CREATE\s+(?:UNIQUE\s+)?INDEX\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[""\[`]?\w+[""\]`]?\.)?[""\[`]?(?<name>\w+)[""\]`]?\s+ON\s+[""\[`]?(?<table>\w+)",
            RegexOptions.IgnoreCase);

        // Trigger name and owning table: CREATE TRIGGER name [BEFORE|AFTER|INSTEAD OF] event ON table ...
        // The lazy ".*?" skips the timing/event clause (e.g. "AFTER UPDATE OF col") up to the first ON keyword.
        private static readonly Regex TriggerHeaderRegex = new Regex(
            @"^CREATE\s+(?:TEMP(?:ORARY)?\s+)?TRIGGER\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[""\[`]?\w+[""\]`]?\.)?[""\[`]?(?<name>\w+)[""\]`]?\s.*?\bON\s+[""\[`]?(?<table>\w+)",
            RegexOptions.IgnoreCase | RegexOptions.Singleline);

        // Column list of a CREATE TABLE read from sqlite_master. Tolerates a missing space before "("
        // and quoted table names (which may contain spaces).
        private static readonly Regex CreateTableRegex = new Regex(
            @"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:""[^""]*""|\[[^\]]*\]|`[^`]*`|\S+?)\s*\((?<columns>.*)\)",
            RegexOptions.Singleline | RegexOptions.IgnoreCase);

        // Leading "CREATE [UNIQUE|TEMP] TABLE|INDEX|TRIGGER" not already followed by IF NOT EXISTS.
        private static readonly Regex IfNotExistsRegex = new Regex(
            @"^(\s*CREATE\s+(?:UNIQUE\s+|TEMP\s+|TEMPORARY\s+)?(?:TABLE|INDEX|TRIGGER))\s+(?!\s*IF\s+NOT\s+EXISTS\b)",
            RegexOptions.IgnoreCase);

        // Any line-break convention (CRLF, LF or CR).
        private static readonly Regex LineBreakRegex = new Regex(@"\r\n|\n|\r");

        private static readonly Regex WhitespaceRegex = new Regex(@"\s+");

        /// <summary>Connection string stored by <see cref="Connect"/>.</summary>
        public string ConnectionString { get; set; }

        /// <summary>
        /// SQL text associated with this model. <see cref="LoadSql"/> sets it to the loaded script, and
        /// <see cref="ParseTablesFromSql"/> appends a normalized copy of every table/index/trigger it parses.
        /// </summary>
        public string SqlScript { get; set; }

        /// <summary>Tables known to this model.</summary>
        public List<SqlTable> Tables { get; set; } = new List<SqlTable>();

        public SqlDatabase() {
        }

        /// <summary>
        /// Stores <paramref name="connectionString"/> in <see cref="ConnectionString"/>. No connection is
        /// opened here; <see cref="BuildSql"/> takes its own connection-string argument.
        /// </summary>
        public void Connect(string connectionString) {
            ConnectionString = connectionString;
        }

        /// <summary>
        /// Replaces <see cref="SqlScript"/> with <paramref name="sql"/> and <see cref="Tables"/> with the
        /// tables parsed from it.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public void LoadSql(string sql) {
            // NOTE(review): ParseTablesFromSql appends normalized copies of each parsed statement to
            // SqlScript, so after this call SqlScript holds the raw script followed by those copies.
            // Confirm whether that duplication is intended.
            SqlScript = sql;
            Tables = ParseTablesFromSql(sql).ToList();
        }

        /// <summary>Reads <paramref name="fileName"/> and passes its contents to <see cref="LoadSql"/>.</summary>
        /// <exception cref="FileNotFoundException">The file does not exist.</exception>
        public void LoadSqlFromFile(string fileName) {
            if (!File.Exists(fileName)) {
                throw new FileNotFoundException("SQL file '" + fileName + "' was not found and could not be loaded.");
            }
            string sql = File.ReadAllText(fileName);
            LoadSql(sql);
        }

        /// <summary>True when <see cref="Tables"/> holds a table with this name (case-insensitive); false for null.</summary>
        public bool ContainsTable(string tableName) {
            return FindTable(Tables, tableName) != null;
        }

        /// <summary>The table with this name (case-insensitive), or null when there is none or the name is null.</summary>
        public SqlTable this[string tableName] {
            get {
                return FindTable(Tables, tableName);
            }
        }

        /// <summary>
        /// Parses the CREATE TABLE / INDEX / TRIGGER statements in <paramref name="sql"/>.
        /// Each index and trigger is attached to its table (matched case-insensitively) through
        /// <c>SqlTable.Indexes</c> / <c>SqlTable.Triggers</c>. Other statements are ignored.
        /// </summary>
        /// <remarks>
        /// Parsing is eager. Every parsed table is also added to <see cref="Tables"/>, and every parsed
        /// statement is appended to <see cref="SqlScript"/>. Lines that are blank or start with "--"
        /// (after leading whitespace) are skipped. A statement may begin and end on the same line.
        /// </remarks>
        /// <returns>The tables parsed from this script, in script order.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public IEnumerable<SqlTable> ParseTablesFromSql(string sql) {
            if (sql == null) {
                throw new ArgumentNullException(nameof(sql));
            }

            var parsedTables = new List<SqlTable>();
            var pendingIndexes = new List<SchemaObject>();
            var pendingTriggers = new List<SchemaObject>();
            SqlTable table = null;
            var sb = new StringBuilder();
            ElementKind current = ElementKind.None;

            foreach (string line in LineBreakRegex.Split(sql)) {
                string trimmed = line.Trim();
                if (trimmed.Length == 0 || trimmed.StartsWith("--", StringComparison.Ordinal)) {
                    continue;
                }

                string text;
                if (TableStartRegex.IsMatch(trimmed)) {
                    table = new SqlTable();
                    current = ElementKind.Table;
                    sb = new StringBuilder();
                    text = CollapseWhitespace(trimmed);
                } else if (IndexStartRegex.IsMatch(trimmed)) {
                    current = ElementKind.Index;
                    sb = new StringBuilder();
                    text = CollapseWhitespace(trimmed);
                } else if (TriggerStartRegex.IsMatch(trimmed)) {
                    current = ElementKind.Trigger;
                    sb = new StringBuilder();
                    text = CollapseWhitespace(trimmed);
                } else if (current == ElementKind.Index || current == ElementKind.Trigger) {
                    text = trimmed;
                } else {
                    // Table body: collapse internal whitespace and mark indented lines with one tab.
                    string collapsed = CollapseWhitespace(trimmed);
                    text = char.IsWhiteSpace(line[0]) ? "\t" + collapsed : collapsed;
                }
                sb.AppendLine(text);

                // Not falling through on a start line is what used to lose one-line statements.
                if (!IsElementEnd(current, text)) {
                    continue;
                }

                if (current == ElementKind.Table) {
                    string tableSql = sb.ToString();
                    SqlScript += tableSql + Environment.NewLine + Environment.NewLine;
                    table.ParseSql(tableSql);
                    Tables.Add(table);
                    parsedTables.Add(table);
                } else {
                    // Index and trigger SQL is flattened onto a single line.
                    string objectSql = sb.ToString().Replace(Environment.NewLine, " ");
                    SqlScript += objectSql + Environment.NewLine + Environment.NewLine;
                    if (current == ElementKind.Index) {
                        CollectSchemaObject(pendingIndexes, IndexHeaderRegex, objectSql);
                    } else {
                        CollectSchemaObject(pendingTriggers, TriggerHeaderRegex, objectSql);
                    }
                }
                current = ElementKind.None;
            }

            // Attach only to tables parsed from this script. Searching this.Tables could pick up stale
            // tables with the same name left over from an earlier load.
            foreach (SchemaObject index in pendingIndexes) {
                SqlTable owner = FindTable(parsedTables, index.TableName);
                if (owner != null) {
                    owner.Indexes[index.Name] = index.Sql;
                }
            }
            foreach (SchemaObject trigger in pendingTriggers) {
                SqlTable owner = FindTable(parsedTables, trigger.TableName);
                if (owner != null) {
                    owner.Triggers[trigger.Name] = trigger.Sql;
                }
            }

            return parsedTables;
        }

        /// <summary>
        /// Reads sqlite_master from the database at <paramref name="dbConnectionString"/> and renders every
        /// user table (ordered by name) as a script block. Each block holds the table's CREATE TABLE (one
        /// column per line), then its indexes and triggers (one statement per line), between
        /// "-- BEGIN TABLE" / "-- END TABLE" markers. Each rendered block is also added to <see cref="Tables"/>.
        /// </summary>
        /// <param name="dbConnectionString">SQLite connection string of the database to read.</param>
        /// <param name="includeIfNotExist">When true, each CREATE statement is emitted as CREATE ... IF NOT EXISTS.</param>
        /// <returns>The concatenated script.</returns>
        public async Task<string> BuildSql(string dbConnectionString, bool includeIfNotExist = false) {
            List<SqliteTableDefinition> definitions;
            using (SQLiteConnection cn = new SQLiteConnection(dbConnectionString)) {
                definitions = (await cn.QueryAsync<SqliteTableDefinition>("select * from sqlite_master")).ToList();
            }

            var script = new StringBuilder();
            foreach (SqliteTableDefinition table in definitions.Where(f => f.type == "table").OrderBy(f => f.tbl_name)) {
                // Names beginning with "sqlite_" are reserved for SQLite's internal tables
                // (sqlite_sequence, sqlite_stat1, ...), which cannot be recreated with CREATE TABLE.
                if (table.tbl_name != null && table.tbl_name.StartsWith("sqlite_", StringComparison.OrdinalIgnoreCase)) {
                    continue;
                }

                Match m = string.IsNullOrEmpty(table.sql) ? Match.Empty : CreateTableRegex.Match(table.sql);
                if (!m.Success) {
                    Trace.TraceWarning("Unable to match regex on table " + table.name);
                    continue;
                }

                Group columnGroup = m.Groups["columns"];
                int startIndex = columnGroup.Index;
                int length = columnGroup.Length;
                // Collapse runs of whitespace, then put each column on its own tab-indented line. The
                // lookahead keeps commas inside numeric type arguments, e.g. DECIMAL(10,2), on one line.
                string columns = Regex.Replace(columnGroup.Value, "\\s{2,}", " ");
                columns = Regex.Replace(columns.Replace(", ", ",").Replace(",\n", ","), ",(?!\\d+\\))", ",\r\n\t");

                string header = table.sql.Substring(0, startIndex);
                if (includeIfNotExist) {
                    header = AddIfNotExists(header);
                }

                var tableSql = new StringBuilder();
                tableSql.Append("-- BEGIN TABLE ").Append(table.tbl_name).Append(" --\r\n");
                tableSql.Append(header).Append("\r\n\t")
                        .Append(columns.Trim()).Append("\r\n")
                        .Append(table.sql.Substring(startIndex + length)).Append(";\r\n");
                AppendChildObjects(tableSql, "\r\n-- INDEXES --\r\n", definitions, "index", table.tbl_name, includeIfNotExist);
                AppendChildObjects(tableSql, "\r\n-- TRIGGERS --\r\n", definitions, "trigger", table.tbl_name, includeIfNotExist);
                tableSql.Append("-- END TABLE ").Append(table.tbl_name).Append(" --\r\n\r\n");

                string block = tableSql.ToString();
                // NOTE(review): Tables is appended to, not cleared, so repeated calls (or an earlier
                // LoadSql) leave previous entries in place.
                Tables.Add(new SqlTable(block));
                script.Append(block);
            }
            return script.ToString();
        }

        /// <summary>
        /// Appends <paramref name="heading"/> and then, one per line, every sqlite_master entry of
        /// <paramref name="type"/> belonging to <paramref name="tableName"/>. Appends nothing when
        /// there are no such entries. Entries without SQL (SQLite auto-indexes) are skipped.
        /// </summary>
        private static void AppendChildObjects(StringBuilder target, string heading, IEnumerable<SqliteTableDefinition> definitions,
                                               string type, string tableName, bool includeIfNotExist) {
            List<SqliteTableDefinition> children = definitions
                .Where(f => f.type == type && f.tbl_name == tableName && !string.IsNullOrEmpty(f.sql))
                .ToList();
            if (children.Count == 0) {
                return;
            }
            target.Append(heading);
            foreach (SqliteTableDefinition child in children) {
                // Flatten with any line-break convention, not just this OS's Environment.NewLine.
                string childSql = LineBreakRegex.Replace(child.sql, " ");
                if (includeIfNotExist) {
                    childSql = AddIfNotExists(childSql);
                }
                target.Append(childSql).Append(";\r\n");
            }
        }

        /// <summary>Inserts IF NOT EXISTS after the leading CREATE ... TABLE/INDEX/TRIGGER keywords, unless already present.</summary>
        private static string AddIfNotExists(string createSql) {
            return IfNotExistsRegex.Replace(createSql, "$1 IF NOT EXISTS ", 1);
        }

        /// <summary>Whether <paramref name="line"/> closes a statement of the given kind.</summary>
        private static bool IsElementEnd(ElementKind kind, string line) {
            switch (kind) {
                case ElementKind.Table:
                    return TableEndRegex.IsMatch(line);
                case ElementKind.Index:
                    return line.EndsWith(";", StringComparison.Ordinal);
                case ElementKind.Trigger:
                    return TriggerEndRegex.IsMatch(line);
                default:
                    return false;
            }
        }

        /// <summary>Records an index/trigger under its owning table when its header can be parsed.</summary>
        private static void CollectSchemaObject(List<SchemaObject> target, Regex headerRegex, string objectSql) {
            Match header = headerRegex.Match(objectSql);
            if (header.Success) {
                target.Add(new SchemaObject(header.Groups["table"].Value, header.Groups["name"].Value, objectSql));
            }
        }

        /// <summary>Trims <paramref name="text"/> and replaces each run of whitespace with a single space.</summary>
        private static string CollapseWhitespace(string text) {
            return WhitespaceRegex.Replace(text.Trim(), " ");
        }

        /// <summary>
        /// First table named <paramref name="tableName"/>, compared ordinally and case-insensitively
        /// because SQLite identifiers are case-insensitive. Returns null for a null name.
        /// </summary>
        private static SqlTable FindTable(IEnumerable<SqlTable> tables, string tableName) {
            if (tableName == null) {
                return null;
            }
            return tables.FirstOrDefault(t => string.Equals(t.TableName, tableName, StringComparison.OrdinalIgnoreCase));
        }
    }
}