diff --git a/osu.Game/Beatmaps/BeatmapManager.cs b/osu.Game/Beatmaps/BeatmapManager.cs
index 4ec153c78f..41ea293938 100644
--- a/osu.Game/Beatmaps/BeatmapManager.cs
+++ b/osu.Game/Beatmaps/BeatmapManager.cs
@@ -172,7 +172,7 @@ namespace osu.Game.Beatmaps
/// The beatmap to be imported.
public BeatmapSetInfo Import(ArchiveReader archive)
{
- using ( contextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes.
+ using (contextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes.
{
// create a new set info (don't yet add to database)
var beatmapSet = createBeatmapSetInfo(archive);
@@ -181,7 +181,7 @@ namespace osu.Game.Beatmaps
var existingHashMatch = beatmaps.BeatmapSets.FirstOrDefault(b => b.Hash == beatmapSet.Hash);
if (existingHashMatch != null)
{
- undelete(existingHashMatch);
+ Undelete(existingHashMatch);
return existingHashMatch;
}
@@ -315,9 +315,9 @@ namespace osu.Game.Beatmaps
/// The beatmap set to delete.
public void Delete(BeatmapSetInfo beatmapSet)
{
- using (var db = contextFactory.GetForWrite())
+ using (var usage = contextFactory.GetForWrite())
{
- var context = db.Context;
+ var context = usage.Context;
context.ChangeTracker.AutoDetectChangesEnabled = false;
@@ -378,11 +378,16 @@ namespace osu.Game.Beatmaps
if (beatmapSet.Protected)
return;
- using (var db = contextFactory.GetForWrite())
+ using (var usage = contextFactory.GetForWrite())
{
- db.Context.ChangeTracker.AutoDetectChangesEnabled = false;
- undelete(beatmapSet);
- db.Context.ChangeTracker.AutoDetectChangesEnabled = true;
+ usage.Context.ChangeTracker.AutoDetectChangesEnabled = false;
+
+ if (!beatmaps.Undelete(beatmapSet)) return;
+
+ if (!beatmapSet.Protected)
+ files.Reference(beatmapSet.Files.Select(f => f.FileInfo).ToArray());
+
+ usage.Context.ChangeTracker.AutoDetectChangesEnabled = true;
}
}
@@ -398,21 +403,6 @@ namespace osu.Game.Beatmaps
/// The beatmap difficulty to restore.
public void Restore(BeatmapInfo beatmap) => beatmaps.Restore(beatmap);
- ///
- /// Returns a to a usable state if it has previously been deleted but not yet purged.
- /// Is a no-op for already usable beatmaps.
- ///
- /// The store to restore beatmaps from.
- /// The store to restore beatmap files from.
- /// The beatmap to restore.
- private void undelete(BeatmapSetInfo beatmapSet)
- {
- if (!beatmaps.Undelete(beatmapSet)) return;
-
- if (!beatmapSet.Protected)
- files.Reference(beatmapSet.Files.Select(f => f.FileInfo).ToArray());
- }
-
///
/// Retrieve a instance for the provided
///
diff --git a/osu.Game/Beatmaps/BeatmapStore.cs b/osu.Game/Beatmaps/BeatmapStore.cs
index 67a2bbbd90..7a1dc763f0 100644
--- a/osu.Game/Beatmaps/BeatmapStore.cs
+++ b/osu.Game/Beatmaps/BeatmapStore.cs
@@ -31,9 +31,9 @@ namespace osu.Game.Beatmaps
/// The beatmap to add.
public void Add(BeatmapSetInfo beatmapSet)
{
- using (var db = ContextFactory.GetForWrite())
+ using (var usage = ContextFactory.GetForWrite())
{
- var context = db.Context;
+ var context = usage.Context;
foreach (var beatmap in beatmapSet.Beatmaps.Where(b => b.Metadata != null))
{
@@ -48,6 +48,7 @@ namespace osu.Game.Beatmaps
}
context.BeatmapSetInfo.Attach(beatmapSet);
+
BeatmapSetAdded?.Invoke(beatmapSet);
}
}
@@ -73,11 +74,12 @@ namespace osu.Game.Beatmaps
/// Whether the beatmap's was changed.
public bool Delete(BeatmapSetInfo beatmapSet)
{
- using ( ContextFactory.GetForWrite())
+ using (ContextFactory.GetForWrite())
{
Refresh(ref beatmapSet, BeatmapSets);
if (beatmapSet.DeletePending) return false;
+
beatmapSet.DeletePending = true;
}
@@ -92,11 +94,12 @@ namespace osu.Game.Beatmaps
/// Whether the beatmap's was changed.
public bool Undelete(BeatmapSetInfo beatmapSet)
{
- using ( ContextFactory.GetForWrite())
+ using (ContextFactory.GetForWrite())
{
Refresh(ref beatmapSet, BeatmapSets);
if (!beatmapSet.DeletePending) return false;
+
beatmapSet.DeletePending = false;
}
@@ -116,6 +119,7 @@ namespace osu.Game.Beatmaps
Refresh(ref beatmap, Beatmaps);
if (beatmap.Hidden) return false;
+
beatmap.Hidden = true;
BeatmapHidden?.Invoke(beatmap);
@@ -136,6 +140,7 @@ namespace osu.Game.Beatmaps
Refresh(ref beatmap, Beatmaps);
if (!beatmap.Hidden) return false;
+
beatmap.Hidden = false;
}
@@ -155,7 +160,9 @@ namespace osu.Game.Beatmaps
.Where(query)
.Include(s => s.Beatmaps).ThenInclude(b => b.Metadata)
.Include(s => s.Beatmaps).ThenInclude(b => b.BaseDifficulty)
- .Include(s => s.Metadata);
+ .Include(s => s.Metadata).ToList();
+
+ if (!purgeable.Any()) return;
// metadata is M-N so we can't rely on cascades
context.BeatmapMetadata.RemoveRange(purgeable.Select(s => s.Metadata));
diff --git a/osu.Game/Database/DatabaseBackedStore.cs b/osu.Game/Database/DatabaseBackedStore.cs
index da66167b14..0b2f34f6d1 100644
--- a/osu.Game/Database/DatabaseBackedStore.cs
+++ b/osu.Game/Database/DatabaseBackedStore.cs
@@ -34,10 +34,7 @@ namespace osu.Game.Database
var id = obj.ID;
var foundObject = lookupSource?.SingleOrDefault(t => t.ID == id) ?? context.Find(id);
if (foundObject != null)
- {
obj = foundObject;
- context.Entry(obj).Reload();
- }
else
context.Add(obj);
}
diff --git a/osu.Game/Database/DatabaseContextFactory.cs b/osu.Game/Database/DatabaseContextFactory.cs
index c092ed377f..2291374e46 100644
--- a/osu.Game/Database/DatabaseContextFactory.cs
+++ b/osu.Game/Database/DatabaseContextFactory.cs
@@ -1,6 +1,7 @@
// Copyright (c) 2007-2018 ppy Pty Ltd .
// Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE
+using System.Diagnostics;
using System.Threading;
using osu.Framework.Platform;
@@ -18,6 +19,7 @@ namespace osu.Game.Database
private OsuDbContext writeContext;
+ private bool currentWriteDidWrite;
private volatile int currentWriteUsages;
public DatabaseContextFactory(GameHost host)
@@ -38,24 +40,41 @@ namespace osu.Game.Database
/// A usage containing a usable context.
public DatabaseWriteUsage GetForWrite()
{
- lock (writeLock)
- {
- var usage = new DatabaseWriteUsage(writeContext ?? (writeContext = threadContexts.Value), usageCompleted);
- Interlocked.Increment(ref currentWriteUsages);
- return usage;
- }
+ Monitor.Enter(writeLock);
+
+ Trace.Assert(currentWriteUsages == 0, "Database writes in a bad state");
+ Interlocked.Increment(ref currentWriteUsages);
+
+ return new DatabaseWriteUsage(writeContext ?? (writeContext = threadContexts.Value), usageCompleted);
}
private void usageCompleted(DatabaseWriteUsage usage)
{
int usages = Interlocked.Decrement(ref currentWriteUsages);
- if (usages == 0)
+
+ try
{
- writeContext.Dispose();
+ currentWriteDidWrite |= usage.PerformedWrite;
+
+ if (usages > 0) return;
+
+
+ if (currentWriteDidWrite)
+ {
+ writeContext.Dispose();
+ currentWriteDidWrite = false;
+
+ // once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches.
+ recycleThreadContexts();
+ }
+
+ // always set to null (even when a write didn't occur) so we get the correct thread context on next write request.
writeContext = null;
- // once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches.
- recycleThreadContexts();
+ }
+ finally
+ {
+ Monitor.Exit(writeLock);
}
}
diff --git a/osu.Game/Database/DatabaseWriteUsage.cs b/osu.Game/Database/DatabaseWriteUsage.cs
index 0dc5a4cfe9..52dd0ee268 100644
--- a/osu.Game/Database/DatabaseWriteUsage.cs
+++ b/osu.Game/Database/DatabaseWriteUsage.cs
@@ -19,10 +19,28 @@ namespace osu.Game.Database
usageCompleted = onCompleted;
}
+ public bool PerformedWrite { get; private set; }
+
+ private bool isDisposed;
+
+ protected void Dispose(bool disposing)
+ {
+ if (isDisposed) return;
+ isDisposed = true;
+
+ PerformedWrite |= Context.SaveChanges(transaction) > 0;
+ usageCompleted?.Invoke(this);
+ }
+
public void Dispose()
{
- Context.SaveChanges(transaction);
- usageCompleted?.Invoke(this);
+ Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+
+ ~DatabaseWriteUsage()
+ {
+ Dispose(false);
}
}
}
diff --git a/osu.Game/Database/OsuDbContext.cs b/osu.Game/Database/OsuDbContext.cs
index cf29ae4496..e83b30595e 100644
--- a/osu.Game/Database/OsuDbContext.cs
+++ b/osu.Game/Database/OsuDbContext.cs
@@ -111,7 +111,7 @@ namespace osu.Game.Database
public int SaveChanges(IDbContextTransaction transaction = null)
{
var ret = base.SaveChanges();
- transaction?.Commit();
+ if (ret > 0) transaction?.Commit();
return ret;
}
diff --git a/osu.Game/IO/FileStore.cs b/osu.Game/IO/FileStore.cs
index 1bfe4db81a..9889088dc4 100644
--- a/osu.Game/IO/FileStore.cs
+++ b/osu.Game/IO/FileStore.cs
@@ -30,11 +30,9 @@ namespace osu.Game.IO
{
using (var usage = ContextFactory.GetForWrite())
{
- var context = usage.Context;
-
string hash = data.ComputeSHA2Hash();
- var existing = context.FileInfo.FirstOrDefault(f => f.Hash == hash);
+ var existing = usage.Context.FileInfo.FirstOrDefault(f => f.Hash == hash);
var info = existing ?? new FileInfo { Hash = hash };
@@ -60,6 +58,8 @@ namespace osu.Game.IO
public void Reference(params FileInfo[] files)
{
+ if (files.Length == 0) return;
+
using (var usage = ContextFactory.GetForWrite())
{
var context = usage.Context;
@@ -75,9 +75,12 @@ namespace osu.Game.IO
public void Dereference(params FileInfo[] files)
{
+ if (files.Length == 0) return;
+
using (var usage = ContextFactory.GetForWrite())
{
var context = usage.Context;
+
foreach (var f in files.GroupBy(f => f.ID))
{
var refetch = context.FileInfo.Find(f.Key);
diff --git a/osu.Game/Input/KeyBindingStore.cs b/osu.Game/Input/KeyBindingStore.cs
index 4aad684959..33cb0911a8 100644
--- a/osu.Game/Input/KeyBindingStore.cs
+++ b/osu.Game/Input/KeyBindingStore.cs
@@ -36,8 +36,6 @@ namespace osu.Game.Input
{
using (var usage = ContextFactory.GetForWrite())
{
- var context = usage.Context;
-
// compare counts in database vs defaults
foreach (var group in defaults.GroupBy(k => k.Action))
{
@@ -49,7 +47,7 @@ namespace osu.Game.Input
foreach (var insertable in group.Skip(count).Take(aimCount - count))
// insert any defaults which are missing.
- context.DatabasedKeyBinding.Add(new DatabasedKeyBinding
+ usage.Context.DatabasedKeyBinding.Add(new DatabasedKeyBinding
{
KeyCombination = insertable.KeyCombination,
Action = insertable.Action,
@@ -75,6 +73,10 @@ namespace osu.Game.Input
{
var dbKeyBinding = (DatabasedKeyBinding)keyBinding;
Refresh(ref dbKeyBinding);
+
+ if (dbKeyBinding.KeyCombination.Equals(keyBinding.KeyCombination))
+ return;
+
dbKeyBinding.KeyCombination = keyBinding.KeyCombination;
}