diff --git a/osu.Desktop/osu.Desktop.csproj b/osu.Desktop/osu.Desktop.csproj index 3d64cab84e..3a35568f8f 100644 --- a/osu.Desktop/osu.Desktop.csproj +++ b/osu.Desktop/osu.Desktop.csproj @@ -1,4 +1,4 @@ - + net471;netcoreapp2.0 @@ -30,10 +30,10 @@ - + - + \ No newline at end of file diff --git a/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs b/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs index f60caf2397..586217a05f 100644 --- a/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs +++ b/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs @@ -12,6 +12,7 @@ using osu.Framework.Platform; using osu.Game.IPC; using osu.Framework.Allocation; using osu.Game.Beatmaps; +using SharpCompress.Archives.Zip; namespace osu.Game.Tests.Beatmaps.IO { @@ -77,8 +78,69 @@ namespace osu.Game.Tests.Beatmaps.IO var manager = osu.Dependencies.Get(); - Assert.IsTrue(manager.GetAllUsableBeatmapSets().Count == 1); - Assert.IsTrue(manager.QueryBeatmapSets(_ => true).ToList().Count == 1); + Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count); + Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count); + } + finally + { + host.Exit(); + } + } + } + + [Test] + public void TestRollbackOnFailure() + { + //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. + using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestRollbackOnFailure")) + { + try + { + var osu = loadOsu(host); + var manager = osu.Dependencies.Get(); + + int fireCount = 0; + + // ReSharper disable once AccessToModifiedClosure + manager.ItemAdded += _ => fireCount++; + manager.ItemRemoved += _ => fireCount++; + + var imported = loadOszIntoOsu(osu); + + Assert.AreEqual(0, fireCount -= 1); + + imported.Hash += "-changed"; + manager.Update(imported); + + Assert.AreEqual(0, fireCount -= 2); + + var breakTemp = createTemporaryBeatmap(); + + MemoryStream brokenOsu = new MemoryStream(new byte[] { 1, 3, 3, 7 }); + MemoryStream brokenOsz = new MemoryStream(File.ReadAllBytes(breakTemp)); + + File.Delete(breakTemp); + + using (var outStream = File.Open(breakTemp, FileMode.CreateNew)) + using (var zip = ZipArchive.Open(brokenOsz)) + { + zip.AddEntry("broken.osu", brokenOsu, false); + zip.SaveTo(outStream, SharpCompress.Common.CompressionType.Deflate); + } + + Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count); + Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count); + Assert.AreEqual(12, manager.QueryBeatmaps(_ => true).ToList().Count); + + // this will trigger purging of the existing beatmap (online set id match) but should rollback due to broken osu. + manager.Import(breakTemp); + + // no events should be fired in the case of a rollback. + Assert.AreEqual(0, fireCount); + + Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count); + Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count); + Assert.AreEqual(12, manager.QueryBeatmaps(_ => true).ToList().Count); } finally { @@ -100,18 +162,17 @@ namespace osu.Game.Tests.Beatmaps.IO var imported = loadOszIntoOsu(osu); - //var change = manager.QueryBeatmapSets(_ => true).First(); imported.Hash += "-changed"; manager.Update(imported); var importedSecondTime = loadOszIntoOsu(osu); - // check the newly "imported" beatmap is actually just the restored previous import. since it matches hash. Assert.IsTrue(imported.ID != importedSecondTime.ID); Assert.IsTrue(imported.Beatmaps.First().ID < importedSecondTime.Beatmaps.First().ID); - Assert.IsTrue(manager.GetAllUsableBeatmapSets().Count == 1); - Assert.IsTrue(manager.QueryBeatmapSets(_ => true).ToList().Count == 1); + // only one beatmap will exist as the online set ID matched, causing purging of the first import. + Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count); + Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count); } finally { @@ -162,8 +223,7 @@ namespace osu.Game.Tests.Beatmaps.IO var osu = loadOsu(host); - var temp = prepareTempCopy(osz_path); - Assert.IsTrue(File.Exists(temp)); + var temp = createTemporaryBeatmap(); var importer = new ArchiveImportIPCChannel(client); if (!importer.ImportAsync(temp).Wait(10000)) @@ -188,8 +248,7 @@ namespace osu.Game.Tests.Beatmaps.IO try { var osu = loadOsu(host); - var temp = prepareTempCopy(osz_path); - Assert.IsTrue(File.Exists(temp), "Temporary file copy never substantiated"); + var temp = createTemporaryBeatmap(); using (File.OpenRead(temp)) osu.Dependencies.Get().Import(temp); ensureLoaded(osu); @@ -203,11 +262,16 @@ namespace osu.Game.Tests.Beatmaps.IO } } - private BeatmapSetInfo loadOszIntoOsu(OsuGameBase osu) + private string createTemporaryBeatmap() { - var temp = prepareTempCopy(osz_path); - + var temp = new FileInfo(osz_path).CopyTo(Path.GetTempFileName(), true).FullName; Assert.IsTrue(File.Exists(temp)); + return temp; + } + + private BeatmapSetInfo loadOszIntoOsu(OsuGameBase osu, string path = null) + { + var temp = path ?? createTemporaryBeatmap(); var manager = osu.Dependencies.Get(); @@ -219,7 +283,7 @@ namespace osu.Game.Tests.Beatmaps.IO waitForOrAssert(() => !File.Exists(temp), "Temporary file still exists after standard import", 5000); - return imported.FirstOrDefault(); + return imported.LastOrDefault(); } private void deleteBeatmapSet(BeatmapSetInfo imported, OsuGameBase osu) @@ -228,16 +292,10 @@ namespace osu.Game.Tests.Beatmaps.IO manager.Delete(imported); Assert.IsTrue(manager.GetAllUsableBeatmapSets().Count == 0); - Assert.IsTrue(manager.QueryBeatmapSets(_ => true).ToList().Count == 1); + Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count); Assert.IsTrue(manager.QueryBeatmapSets(_ => true).First().DeletePending); } - private string prepareTempCopy(string path) - { - var temp = Path.GetTempFileName(); - return new FileInfo(path).CopyTo(temp, true).FullName; - } - private OsuGameBase loadOsu(GameHost host) { var osu = new OsuGameBase(); diff --git a/osu.Game/Database/ArchiveModelManager.cs b/osu.Game/Database/ArchiveModelManager.cs index e04559d547..1505ac0549 100644 --- a/osu.Game/Database/ArchiveModelManager.cs +++ b/osu.Game/Database/ArchiveModelManager.cs @@ -56,13 +56,49 @@ namespace osu.Game.Database // ReSharper disable once NotAccessedField.Local (we should keep a reference to this so it is not finalised) private ArchiveImportIPCChannel ipc; + private readonly List cachedEvents = new List(); + + /// + /// Allows delaying of outwards events until an operation is confirmed (at a database level). + /// + private bool delayingEvents; + + /// + /// Begin delaying outwards events. + /// + private void delayEvents() => delayingEvents = true; + + /// + /// Flush delayed events and disable delaying. + /// + /// Whether the flushed events should be performed. + private void flushEvents(bool perform) + { + if (perform) + { + foreach (var a in cachedEvents) + a.Invoke(); + } + + cachedEvents.Clear(); + delayingEvents = false; + } + + private void handleEvent(Action a) + { + if (delayingEvents) + cachedEvents.Add(a); + else + a.Invoke(); + } + protected ArchiveModelManager(Storage storage, IDatabaseContextFactory contextFactory, MutableDatabaseBackedStore modelStore, IIpcHost importHost = null) { ContextFactory = contextFactory; ModelStore = modelStore; - ModelStore.ItemAdded += s => ItemAdded?.Invoke(s); - ModelStore.ItemRemoved += s => ItemRemoved?.Invoke(s); + ModelStore.ItemAdded += s => handleEvent(() => ItemAdded?.Invoke(s)); + ModelStore.ItemRemoved += s => handleEvent(() => ItemRemoved?.Invoke(s)); Files = new FileStore(contextFactory, storage); @@ -138,24 +174,50 @@ namespace osu.Game.Database /// The archive to be imported. public TModel Import(ArchiveReader archive) { - using (ContextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes. + TModel item = null; + delayEvents(); + + try { - // create a new model (don't yet add to database) - var item = CreateModel(archive); + using (var write = ContextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes. + { + try + { + if (!write.IsTransactionLeader) throw new InvalidOperationException($"Ensure there is no parent transaction so errors can correctly be handled by {this}"); - var existing = CheckForExisting(item); + // create a new model (don't yet add to database) + item = CreateModel(archive); - if (existing != null) return existing; + var existing = CheckForExisting(item); - item.Files = createFileInfos(archive, Files); + if (existing != null) return existing; - Populate(item, archive); + item.Files = createFileInfos(archive, Files); - // import to store - ModelStore.Add(item); + Populate(item, archive); - return item; + // import to store + ModelStore.Add(item); + } + catch (Exception e) + { + write.Errors.Add(e); + throw; + } + } } + catch (Exception e) + { + Logger.Error(e, $"Import of {archive.Name} failed and has been rolled back.", LoggingTarget.Database); + item = null; + } + finally + { + // we only want to flush events after we've confirmed the write context didn't have any errors. + flushEvents(item != null); + } + + return item; } /// @@ -178,12 +240,8 @@ namespace osu.Game.Database /// The item to delete. public void Delete(TModel item) { - using (var usage = ContextFactory.GetForWrite()) + using (ContextFactory.GetForWrite()) { - var context = usage.Context; - - context.ChangeTracker.AutoDetectChangesEnabled = false; - // re-fetch the model on the import context. var foundModel = queryModel().Include(s => s.Files).ThenInclude(f => f.FileInfo).First(s => s.ID == item.ID); @@ -191,8 +249,6 @@ namespace osu.Game.Database if (ModelStore.Delete(foundModel)) Files.Dereference(foundModel.Files.Select(f => f.FileInfo).ToArray()); - - context.ChangeTracker.AutoDetectChangesEnabled = true; } } diff --git a/osu.Game/Database/DatabaseContextFactory.cs b/osu.Game/Database/DatabaseContextFactory.cs index 71960303b5..a1d371f431 100644 --- a/osu.Game/Database/DatabaseContextFactory.cs +++ b/osu.Game/Database/DatabaseContextFactory.cs @@ -1,7 +1,9 @@ // Copyright (c) 2007-2018 ppy Pty Ltd . // Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE +using System.Linq; using System.Threading; +using Microsoft.EntityFrameworkCore.Storage; using osu.Framework.Platform; namespace osu.Game.Database @@ -17,8 +19,12 @@ namespace osu.Game.Database private readonly object writeLock = new object(); private bool currentWriteDidWrite; + private bool currentWriteDidError; + private int currentWriteUsages; + private IDbContextTransaction currentWriteTransaction; + public DatabaseContextFactory(GameHost host) { this.host = host; @@ -35,14 +41,25 @@ namespace osu.Game.Database /// Request a context for write usage. Can be consumed in a nested fashion (and will return the same underlying context). /// This method may block if a write is already active on a different thread. /// + /// Whether to start a transaction for this write. /// A usage containing a usable context. - public DatabaseWriteUsage GetForWrite() + public DatabaseWriteUsage GetForWrite(bool withTransaction = true) { Monitor.Enter(writeLock); + if (currentWriteTransaction == null && withTransaction) + { + // this mitigates the fact that changes on tracked entities will not be rolled back with the transaction by ensuring write operations are always executed in isolated contexts. + // if this results in sub-optimal efficiency, we may need to look into removing Database-level transactions in favour of running SaveChanges where we currently commit the transaction. + if (threadContexts.IsValueCreated) + recycleThreadContexts(); + + currentWriteTransaction = threadContexts.Value.Database.BeginTransaction(); + } + Interlocked.Increment(ref currentWriteUsages); - return new DatabaseWriteUsage(threadContexts.Value, usageCompleted); + return new DatabaseWriteUsage(threadContexts.Value, usageCompleted) { IsTransactionLeader = currentWriteTransaction != null && currentWriteUsages == 1 }; } private void usageCompleted(DatabaseWriteUsage usage) @@ -52,18 +69,27 @@ namespace osu.Game.Database try { currentWriteDidWrite |= usage.PerformedWrite; + currentWriteDidError |= usage.Errors.Any(); - if (usages > 0) return; - - if (currentWriteDidWrite) + if (usages == 0) { - // explicitly dispose to ensure any outstanding flushes happen as soon as possible (and underlying resources are purged). - usage.Context.Dispose(); + if (currentWriteDidError) + currentWriteTransaction?.Rollback(); + else + currentWriteTransaction?.Commit(); + if (currentWriteDidWrite || currentWriteDidError) + { + // explicitly dispose to ensure any outstanding flushes happen as soon as possible (and underlying resources are purged). + usage.Context.Dispose(); + + // once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches. + recycleThreadContexts(); + } + + currentWriteTransaction = null; currentWriteDidWrite = false; - - // once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches. - recycleThreadContexts(); + currentWriteDidError = false; } } finally diff --git a/osu.Game/Database/DatabaseWriteUsage.cs b/osu.Game/Database/DatabaseWriteUsage.cs index 7858c1a0d1..64ab24e824 100644 --- a/osu.Game/Database/DatabaseWriteUsage.cs +++ b/osu.Game/Database/DatabaseWriteUsage.cs @@ -2,34 +2,50 @@ // Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE using System; -using Microsoft.EntityFrameworkCore.Storage; +using System.Collections.Generic; namespace osu.Game.Database { public class DatabaseWriteUsage : IDisposable { public readonly OsuDbContext Context; - private readonly IDbContextTransaction transaction; private readonly Action usageCompleted; public DatabaseWriteUsage(OsuDbContext context, Action onCompleted) { Context = context; - transaction = Context.BeginTransaction(); usageCompleted = onCompleted; } public bool PerformedWrite { get; private set; } private bool isDisposed; + public List Errors = new List(); + + /// + /// Whether this write usage will commit a transaction on completion. + /// If false, there is a parent usage responsible for transaction commit. + /// + public bool IsTransactionLeader = false; protected void Dispose(bool disposing) { if (isDisposed) return; isDisposed = true; - PerformedWrite |= Context.SaveChanges(transaction) > 0; - usageCompleted?.Invoke(this); + try + { + PerformedWrite |= Context.SaveChanges() > 0; + } + catch (Exception e) + { + Errors.Add(e); + throw; + } + finally + { + usageCompleted?.Invoke(this); + } } public void Dispose() diff --git a/osu.Game/Database/IDatabaseContextFactory.cs b/osu.Game/Database/IDatabaseContextFactory.cs index 372e1770e4..d38d15b252 100644 --- a/osu.Game/Database/IDatabaseContextFactory.cs +++ b/osu.Game/Database/IDatabaseContextFactory.cs @@ -14,7 +14,8 @@ namespace osu.Game.Database /// Request a context for write usage. Can be consumed in a nested fashion (and will return the same underlying context). /// This method may block if a write is already active on a different thread. /// + /// Whether to start a transaction for this write. /// A usage containing a usable context. - DatabaseWriteUsage GetForWrite(); + DatabaseWriteUsage GetForWrite(bool withTransaction = true); } } diff --git a/osu.Game/Database/MutableDatabaseBackedStore.cs b/osu.Game/Database/MutableDatabaseBackedStore.cs index 8569d81f01..69a1f57cc4 100644 --- a/osu.Game/Database/MutableDatabaseBackedStore.cs +++ b/osu.Game/Database/MutableDatabaseBackedStore.cs @@ -50,11 +50,10 @@ namespace osu.Game.Database /// The item to update. public void Update(T item) { - ItemRemoved?.Invoke(item); - using (var usage = ContextFactory.GetForWrite()) usage.Context.Update(item); + ItemRemoved?.Invoke(item); ItemAdded?.Invoke(item); } diff --git a/osu.Game/Database/OsuDbContext.cs b/osu.Game/Database/OsuDbContext.cs index 1979ce3648..0ae197d62d 100644 --- a/osu.Game/Database/OsuDbContext.cs +++ b/osu.Game/Database/OsuDbContext.cs @@ -3,7 +3,6 @@ using System; using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.Extensions.Logging; using osu.Framework.Logging; @@ -104,19 +103,6 @@ namespace osu.Game.Database modelBuilder.Entity().HasOne(b => b.BaseDifficulty); } - public IDbContextTransaction BeginTransaction() - { - // return Database.BeginTransaction(); - return null; - } - - public int SaveChanges(IDbContextTransaction transaction = null) - { - var ret = base.SaveChanges(); - if (ret > 0) transaction?.Commit(); - return ret; - } - private class OsuDbLoggerFactory : ILoggerFactory { #region Disposal diff --git a/osu.Game/Database/SingletonContextFactory.cs b/osu.Game/Database/SingletonContextFactory.cs index 74951e8433..ce3fbf6881 100644 --- a/osu.Game/Database/SingletonContextFactory.cs +++ b/osu.Game/Database/SingletonContextFactory.cs @@ -14,6 +14,6 @@ namespace osu.Game.Database public OsuDbContext Get() => context; - public DatabaseWriteUsage GetForWrite() => new DatabaseWriteUsage(context, null); + public DatabaseWriteUsage GetForWrite(bool withTransaction = true) => new DatabaseWriteUsage(context, null); } } diff --git a/osu.Game/OsuGameBase.cs b/osu.Game/OsuGameBase.cs index a3a081d6d1..b9d32a6322 100644 --- a/osu.Game/OsuGameBase.cs +++ b/osu.Game/OsuGameBase.cs @@ -208,7 +208,7 @@ namespace osu.Game { try { - using (var db = contextFactory.GetForWrite()) + using (var db = contextFactory.GetForWrite(false)) db.Context.Migrate(); } catch (MigrationFailedException e) @@ -220,7 +220,7 @@ namespace osu.Game contextFactory.ResetDatabase(); Logger.Log("Database purged successfully.", LoggingTarget.Database, LogLevel.Important); - using (var db = contextFactory.GetForWrite()) + using (var db = contextFactory.GetForWrite(false)) db.Context.Migrate(); } } diff --git a/osu.Game/osu.Game.csproj b/osu.Game/osu.Game.csproj index 1a75f1979a..afb656a260 100644 --- a/osu.Game/osu.Game.csproj +++ b/osu.Game/osu.Game.csproj @@ -17,7 +17,7 @@ - +