diff --git a/osu.Desktop/osu.Desktop.csproj b/osu.Desktop/osu.Desktop.csproj
index 3d64cab84e..3a35568f8f 100644
--- a/osu.Desktop/osu.Desktop.csproj
+++ b/osu.Desktop/osu.Desktop.csproj
@@ -1,4 +1,4 @@
-
+
net471;netcoreapp2.0
@@ -30,10 +30,10 @@
-
+
-
+
\ No newline at end of file
diff --git a/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs b/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs
index f60caf2397..586217a05f 100644
--- a/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs
+++ b/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs
@@ -12,6 +12,7 @@ using osu.Framework.Platform;
using osu.Game.IPC;
using osu.Framework.Allocation;
using osu.Game.Beatmaps;
+using SharpCompress.Archives.Zip;
namespace osu.Game.Tests.Beatmaps.IO
{
@@ -77,8 +78,69 @@ namespace osu.Game.Tests.Beatmaps.IO
var manager = osu.Dependencies.Get();
- Assert.IsTrue(manager.GetAllUsableBeatmapSets().Count == 1);
- Assert.IsTrue(manager.QueryBeatmapSets(_ => true).ToList().Count == 1);
+ Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count);
+ Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count);
+ }
+ finally
+ {
+ host.Exit();
+ }
+ }
+ }
+
+ [Test]
+ public void TestRollbackOnFailure()
+ {
+ //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here.
+ using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestRollbackOnFailure"))
+ {
+ try
+ {
+ var osu = loadOsu(host);
+ var manager = osu.Dependencies.Get();
+
+ int fireCount = 0;
+
+ // ReSharper disable once AccessToModifiedClosure
+ manager.ItemAdded += _ => fireCount++;
+ manager.ItemRemoved += _ => fireCount++;
+
+ var imported = loadOszIntoOsu(osu);
+
+ Assert.AreEqual(0, fireCount -= 1);
+
+ imported.Hash += "-changed";
+ manager.Update(imported);
+
+ Assert.AreEqual(0, fireCount -= 2);
+
+ var breakTemp = createTemporaryBeatmap();
+
+ MemoryStream brokenOsu = new MemoryStream(new byte[] { 1, 3, 3, 7 });
+ MemoryStream brokenOsz = new MemoryStream(File.ReadAllBytes(breakTemp));
+
+ File.Delete(breakTemp);
+
+ using (var outStream = File.Open(breakTemp, FileMode.CreateNew))
+ using (var zip = ZipArchive.Open(brokenOsz))
+ {
+ zip.AddEntry("broken.osu", brokenOsu, false);
+ zip.SaveTo(outStream, SharpCompress.Common.CompressionType.Deflate);
+ }
+
+ Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count);
+ Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count);
+ Assert.AreEqual(12, manager.QueryBeatmaps(_ => true).ToList().Count);
+
+ // this will trigger purging of the existing beatmap (online set id match) but should rollback due to broken osu.
+ manager.Import(breakTemp);
+
+ // no events should be fired in the case of a rollback.
+ Assert.AreEqual(0, fireCount);
+
+ Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count);
+ Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count);
+ Assert.AreEqual(12, manager.QueryBeatmaps(_ => true).ToList().Count);
}
finally
{
@@ -100,18 +162,17 @@ namespace osu.Game.Tests.Beatmaps.IO
var imported = loadOszIntoOsu(osu);
- //var change = manager.QueryBeatmapSets(_ => true).First();
imported.Hash += "-changed";
manager.Update(imported);
var importedSecondTime = loadOszIntoOsu(osu);
- // check the newly "imported" beatmap is actually just the restored previous import. since it matches hash.
Assert.IsTrue(imported.ID != importedSecondTime.ID);
Assert.IsTrue(imported.Beatmaps.First().ID < importedSecondTime.Beatmaps.First().ID);
- Assert.IsTrue(manager.GetAllUsableBeatmapSets().Count == 1);
- Assert.IsTrue(manager.QueryBeatmapSets(_ => true).ToList().Count == 1);
+ // only one beatmap will exist as the online set ID matched, causing purging of the first import.
+ Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count);
+ Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count);
}
finally
{
@@ -162,8 +223,7 @@ namespace osu.Game.Tests.Beatmaps.IO
var osu = loadOsu(host);
- var temp = prepareTempCopy(osz_path);
- Assert.IsTrue(File.Exists(temp));
+ var temp = createTemporaryBeatmap();
var importer = new ArchiveImportIPCChannel(client);
if (!importer.ImportAsync(temp).Wait(10000))
@@ -188,8 +248,7 @@ namespace osu.Game.Tests.Beatmaps.IO
try
{
var osu = loadOsu(host);
- var temp = prepareTempCopy(osz_path);
- Assert.IsTrue(File.Exists(temp), "Temporary file copy never substantiated");
+ var temp = createTemporaryBeatmap();
using (File.OpenRead(temp))
osu.Dependencies.Get().Import(temp);
ensureLoaded(osu);
@@ -203,11 +262,16 @@ namespace osu.Game.Tests.Beatmaps.IO
}
}
- private BeatmapSetInfo loadOszIntoOsu(OsuGameBase osu)
+ private string createTemporaryBeatmap()
{
- var temp = prepareTempCopy(osz_path);
-
+ var temp = new FileInfo(osz_path).CopyTo(Path.GetTempFileName(), true).FullName;
Assert.IsTrue(File.Exists(temp));
+ return temp;
+ }
+
+ private BeatmapSetInfo loadOszIntoOsu(OsuGameBase osu, string path = null)
+ {
+ var temp = path ?? createTemporaryBeatmap();
var manager = osu.Dependencies.Get();
@@ -219,7 +283,7 @@ namespace osu.Game.Tests.Beatmaps.IO
waitForOrAssert(() => !File.Exists(temp), "Temporary file still exists after standard import", 5000);
- return imported.FirstOrDefault();
+ return imported.LastOrDefault();
}
private void deleteBeatmapSet(BeatmapSetInfo imported, OsuGameBase osu)
@@ -228,16 +292,10 @@ namespace osu.Game.Tests.Beatmaps.IO
manager.Delete(imported);
Assert.IsTrue(manager.GetAllUsableBeatmapSets().Count == 0);
- Assert.IsTrue(manager.QueryBeatmapSets(_ => true).ToList().Count == 1);
+ Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count);
Assert.IsTrue(manager.QueryBeatmapSets(_ => true).First().DeletePending);
}
- private string prepareTempCopy(string path)
- {
- var temp = Path.GetTempFileName();
- return new FileInfo(path).CopyTo(temp, true).FullName;
- }
-
private OsuGameBase loadOsu(GameHost host)
{
var osu = new OsuGameBase();
diff --git a/osu.Game/Database/ArchiveModelManager.cs b/osu.Game/Database/ArchiveModelManager.cs
index e04559d547..1505ac0549 100644
--- a/osu.Game/Database/ArchiveModelManager.cs
+++ b/osu.Game/Database/ArchiveModelManager.cs
@@ -56,13 +56,49 @@ namespace osu.Game.Database
// ReSharper disable once NotAccessedField.Local (we should keep a reference to this so it is not finalised)
private ArchiveImportIPCChannel ipc;
+ private readonly List cachedEvents = new List();
+
+ ///
+ /// Allows delaying of outwards events until an operation is confirmed (at a database level).
+ ///
+ private bool delayingEvents;
+
+ ///
+ /// Begin delaying outwards events.
+ ///
+ private void delayEvents() => delayingEvents = true;
+
+ ///
+ /// Flush delayed events and disable delaying.
+ ///
+ /// Whether the flushed events should be performed.
+ private void flushEvents(bool perform)
+ {
+ if (perform)
+ {
+ foreach (var a in cachedEvents)
+ a.Invoke();
+ }
+
+ cachedEvents.Clear();
+ delayingEvents = false;
+ }
+
+ private void handleEvent(Action a)
+ {
+ if (delayingEvents)
+ cachedEvents.Add(a);
+ else
+ a.Invoke();
+ }
+
protected ArchiveModelManager(Storage storage, IDatabaseContextFactory contextFactory, MutableDatabaseBackedStore modelStore, IIpcHost importHost = null)
{
ContextFactory = contextFactory;
ModelStore = modelStore;
- ModelStore.ItemAdded += s => ItemAdded?.Invoke(s);
- ModelStore.ItemRemoved += s => ItemRemoved?.Invoke(s);
+ ModelStore.ItemAdded += s => handleEvent(() => ItemAdded?.Invoke(s));
+ ModelStore.ItemRemoved += s => handleEvent(() => ItemRemoved?.Invoke(s));
Files = new FileStore(contextFactory, storage);
@@ -138,24 +174,50 @@ namespace osu.Game.Database
/// The archive to be imported.
public TModel Import(ArchiveReader archive)
{
- using (ContextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes.
+ TModel item = null;
+ delayEvents();
+
+ try
{
- // create a new model (don't yet add to database)
- var item = CreateModel(archive);
+ using (var write = ContextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes.
+ {
+ try
+ {
+ if (!write.IsTransactionLeader) throw new InvalidOperationException($"Ensure there is no parent transaction so errors can correctly be handled by {this}");
- var existing = CheckForExisting(item);
+ // create a new model (don't yet add to database)
+ item = CreateModel(archive);
- if (existing != null) return existing;
+ var existing = CheckForExisting(item);
- item.Files = createFileInfos(archive, Files);
+ if (existing != null) return existing;
- Populate(item, archive);
+ item.Files = createFileInfos(archive, Files);
- // import to store
- ModelStore.Add(item);
+ Populate(item, archive);
- return item;
+ // import to store
+ ModelStore.Add(item);
+ }
+ catch (Exception e)
+ {
+ write.Errors.Add(e);
+ throw;
+ }
+ }
}
+ catch (Exception e)
+ {
+ Logger.Error(e, $"Import of {archive.Name} failed and has been rolled back.", LoggingTarget.Database);
+ item = null;
+ }
+ finally
+ {
+ // we only want to flush events after we've confirmed the write context didn't have any errors.
+ flushEvents(item != null);
+ }
+
+ return item;
}
///
@@ -178,12 +240,8 @@ namespace osu.Game.Database
/// The item to delete.
public void Delete(TModel item)
{
- using (var usage = ContextFactory.GetForWrite())
+ using (ContextFactory.GetForWrite())
{
- var context = usage.Context;
-
- context.ChangeTracker.AutoDetectChangesEnabled = false;
-
// re-fetch the model on the import context.
var foundModel = queryModel().Include(s => s.Files).ThenInclude(f => f.FileInfo).First(s => s.ID == item.ID);
@@ -191,8 +249,6 @@ namespace osu.Game.Database
if (ModelStore.Delete(foundModel))
Files.Dereference(foundModel.Files.Select(f => f.FileInfo).ToArray());
-
- context.ChangeTracker.AutoDetectChangesEnabled = true;
}
}
diff --git a/osu.Game/Database/DatabaseContextFactory.cs b/osu.Game/Database/DatabaseContextFactory.cs
index 71960303b5..a1d371f431 100644
--- a/osu.Game/Database/DatabaseContextFactory.cs
+++ b/osu.Game/Database/DatabaseContextFactory.cs
@@ -1,7 +1,9 @@
// Copyright (c) 2007-2018 ppy Pty Ltd .
// Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE
+using System.Linq;
using System.Threading;
+using Microsoft.EntityFrameworkCore.Storage;
using osu.Framework.Platform;
namespace osu.Game.Database
@@ -17,8 +19,12 @@ namespace osu.Game.Database
private readonly object writeLock = new object();
private bool currentWriteDidWrite;
+ private bool currentWriteDidError;
+
private int currentWriteUsages;
+ private IDbContextTransaction currentWriteTransaction;
+
public DatabaseContextFactory(GameHost host)
{
this.host = host;
@@ -35,14 +41,25 @@ namespace osu.Game.Database
/// Request a context for write usage. Can be consumed in a nested fashion (and will return the same underlying context).
/// This method may block if a write is already active on a different thread.
///
+ /// Whether to start a transaction for this write.
/// A usage containing a usable context.
- public DatabaseWriteUsage GetForWrite()
+ public DatabaseWriteUsage GetForWrite(bool withTransaction = true)
{
Monitor.Enter(writeLock);
+ if (currentWriteTransaction == null && withTransaction)
+ {
+ // this mitigates the fact that changes on tracked entities will not be rolled back with the transaction by ensuring write operations are always executed in isolated contexts.
+ // if this results in sub-optimal efficiency, we may need to look into removing Database-level transactions in favour of running SaveChanges where we currently commit the transaction.
+ if (threadContexts.IsValueCreated)
+ recycleThreadContexts();
+
+ currentWriteTransaction = threadContexts.Value.Database.BeginTransaction();
+ }
+
Interlocked.Increment(ref currentWriteUsages);
- return new DatabaseWriteUsage(threadContexts.Value, usageCompleted);
+ return new DatabaseWriteUsage(threadContexts.Value, usageCompleted) { IsTransactionLeader = currentWriteTransaction != null && currentWriteUsages == 1 };
}
private void usageCompleted(DatabaseWriteUsage usage)
@@ -52,18 +69,27 @@ namespace osu.Game.Database
try
{
currentWriteDidWrite |= usage.PerformedWrite;
+ currentWriteDidError |= usage.Errors.Any();
- if (usages > 0) return;
-
- if (currentWriteDidWrite)
+ if (usages == 0)
{
- // explicitly dispose to ensure any outstanding flushes happen as soon as possible (and underlying resources are purged).
- usage.Context.Dispose();
+ if (currentWriteDidError)
+ currentWriteTransaction?.Rollback();
+ else
+ currentWriteTransaction?.Commit();
+ if (currentWriteDidWrite || currentWriteDidError)
+ {
+ // explicitly dispose to ensure any outstanding flushes happen as soon as possible (and underlying resources are purged).
+ usage.Context.Dispose();
+
+ // once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches.
+ recycleThreadContexts();
+ }
+
+ currentWriteTransaction = null;
currentWriteDidWrite = false;
-
- // once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches.
- recycleThreadContexts();
+ currentWriteDidError = false;
}
}
finally
diff --git a/osu.Game/Database/DatabaseWriteUsage.cs b/osu.Game/Database/DatabaseWriteUsage.cs
index 7858c1a0d1..64ab24e824 100644
--- a/osu.Game/Database/DatabaseWriteUsage.cs
+++ b/osu.Game/Database/DatabaseWriteUsage.cs
@@ -2,34 +2,50 @@
// Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE
using System;
-using Microsoft.EntityFrameworkCore.Storage;
+using System.Collections.Generic;
namespace osu.Game.Database
{
public class DatabaseWriteUsage : IDisposable
{
public readonly OsuDbContext Context;
- private readonly IDbContextTransaction transaction;
private readonly Action usageCompleted;
public DatabaseWriteUsage(OsuDbContext context, Action onCompleted)
{
Context = context;
- transaction = Context.BeginTransaction();
usageCompleted = onCompleted;
}
public bool PerformedWrite { get; private set; }
private bool isDisposed;
+ public List Errors = new List();
+
+ ///
+ /// Whether this write usage will commit a transaction on completion.
+ /// If false, there is a parent usage responsible for transaction commit.
+ ///
+ public bool IsTransactionLeader = false;
protected void Dispose(bool disposing)
{
if (isDisposed) return;
isDisposed = true;
- PerformedWrite |= Context.SaveChanges(transaction) > 0;
- usageCompleted?.Invoke(this);
+ try
+ {
+ PerformedWrite |= Context.SaveChanges() > 0;
+ }
+ catch (Exception e)
+ {
+ Errors.Add(e);
+ throw;
+ }
+ finally
+ {
+ usageCompleted?.Invoke(this);
+ }
}
public void Dispose()
diff --git a/osu.Game/Database/IDatabaseContextFactory.cs b/osu.Game/Database/IDatabaseContextFactory.cs
index 372e1770e4..d38d15b252 100644
--- a/osu.Game/Database/IDatabaseContextFactory.cs
+++ b/osu.Game/Database/IDatabaseContextFactory.cs
@@ -14,7 +14,8 @@ namespace osu.Game.Database
/// Request a context for write usage. Can be consumed in a nested fashion (and will return the same underlying context).
/// This method may block if a write is already active on a different thread.
///
+ /// Whether to start a transaction for this write.
/// A usage containing a usable context.
- DatabaseWriteUsage GetForWrite();
+ DatabaseWriteUsage GetForWrite(bool withTransaction = true);
}
}
diff --git a/osu.Game/Database/MutableDatabaseBackedStore.cs b/osu.Game/Database/MutableDatabaseBackedStore.cs
index 8569d81f01..69a1f57cc4 100644
--- a/osu.Game/Database/MutableDatabaseBackedStore.cs
+++ b/osu.Game/Database/MutableDatabaseBackedStore.cs
@@ -50,11 +50,10 @@ namespace osu.Game.Database
/// The item to update.
public void Update(T item)
{
- ItemRemoved?.Invoke(item);
-
using (var usage = ContextFactory.GetForWrite())
usage.Context.Update(item);
+ ItemRemoved?.Invoke(item);
ItemAdded?.Invoke(item);
}
diff --git a/osu.Game/Database/OsuDbContext.cs b/osu.Game/Database/OsuDbContext.cs
index 1979ce3648..0ae197d62d 100644
--- a/osu.Game/Database/OsuDbContext.cs
+++ b/osu.Game/Database/OsuDbContext.cs
@@ -3,7 +3,6 @@
using System;
using Microsoft.EntityFrameworkCore;
-using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.Extensions.Logging;
using osu.Framework.Logging;
@@ -104,19 +103,6 @@ namespace osu.Game.Database
modelBuilder.Entity().HasOne(b => b.BaseDifficulty);
}
- public IDbContextTransaction BeginTransaction()
- {
- // return Database.BeginTransaction();
- return null;
- }
-
- public int SaveChanges(IDbContextTransaction transaction = null)
- {
- var ret = base.SaveChanges();
- if (ret > 0) transaction?.Commit();
- return ret;
- }
-
private class OsuDbLoggerFactory : ILoggerFactory
{
#region Disposal
diff --git a/osu.Game/Database/SingletonContextFactory.cs b/osu.Game/Database/SingletonContextFactory.cs
index 74951e8433..ce3fbf6881 100644
--- a/osu.Game/Database/SingletonContextFactory.cs
+++ b/osu.Game/Database/SingletonContextFactory.cs
@@ -14,6 +14,6 @@ namespace osu.Game.Database
public OsuDbContext Get() => context;
- public DatabaseWriteUsage GetForWrite() => new DatabaseWriteUsage(context, null);
+ public DatabaseWriteUsage GetForWrite(bool withTransaction = true) => new DatabaseWriteUsage(context, null);
}
}
diff --git a/osu.Game/OsuGameBase.cs b/osu.Game/OsuGameBase.cs
index a3a081d6d1..b9d32a6322 100644
--- a/osu.Game/OsuGameBase.cs
+++ b/osu.Game/OsuGameBase.cs
@@ -208,7 +208,7 @@ namespace osu.Game
{
try
{
- using (var db = contextFactory.GetForWrite())
+ using (var db = contextFactory.GetForWrite(false))
db.Context.Migrate();
}
catch (MigrationFailedException e)
@@ -220,7 +220,7 @@ namespace osu.Game
contextFactory.ResetDatabase();
Logger.Log("Database purged successfully.", LoggingTarget.Database, LogLevel.Important);
- using (var db = contextFactory.GetForWrite())
+ using (var db = contextFactory.GetForWrite(false))
db.Context.Migrate();
}
}
diff --git a/osu.Game/osu.Game.csproj b/osu.Game/osu.Game.csproj
index 1a75f1979a..afb656a260 100644
--- a/osu.Game/osu.Game.csproj
+++ b/osu.Game/osu.Game.csproj
@@ -17,7 +17,7 @@
-
+