diff --git a/osu.Game/Beatmaps/BeatmapManager.cs b/osu.Game/Beatmaps/BeatmapManager.cs index 4ec153c78f..41ea293938 100644 --- a/osu.Game/Beatmaps/BeatmapManager.cs +++ b/osu.Game/Beatmaps/BeatmapManager.cs @@ -172,7 +172,7 @@ namespace osu.Game.Beatmaps /// The beatmap to be imported. public BeatmapSetInfo Import(ArchiveReader archive) { - using ( contextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes. + using (contextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes. { // create a new set info (don't yet add to database) var beatmapSet = createBeatmapSetInfo(archive); @@ -181,7 +181,7 @@ namespace osu.Game.Beatmaps var existingHashMatch = beatmaps.BeatmapSets.FirstOrDefault(b => b.Hash == beatmapSet.Hash); if (existingHashMatch != null) { - undelete(existingHashMatch); + Undelete(existingHashMatch); return existingHashMatch; } @@ -315,9 +315,9 @@ namespace osu.Game.Beatmaps /// The beatmap set to delete. public void Delete(BeatmapSetInfo beatmapSet) { - using (var db = contextFactory.GetForWrite()) + using (var usage = contextFactory.GetForWrite()) { - var context = db.Context; + var context = usage.Context; context.ChangeTracker.AutoDetectChangesEnabled = false; @@ -378,11 +378,16 @@ namespace osu.Game.Beatmaps if (beatmapSet.Protected) return; - using (var db = contextFactory.GetForWrite()) + using (var usage = contextFactory.GetForWrite()) { - db.Context.ChangeTracker.AutoDetectChangesEnabled = false; - undelete(beatmapSet); - db.Context.ChangeTracker.AutoDetectChangesEnabled = true; + usage.Context.ChangeTracker.AutoDetectChangesEnabled = false; + + if (!beatmaps.Undelete(beatmapSet)) return; + + if (!beatmapSet.Protected) + files.Reference(beatmapSet.Files.Select(f => f.FileInfo).ToArray()); + + usage.Context.ChangeTracker.AutoDetectChangesEnabled = true; } } @@ -398,21 +403,6 @@ namespace osu.Game.Beatmaps /// The beatmap difficulty to restore. public void Restore(BeatmapInfo beatmap) => beatmaps.Restore(beatmap); - /// - /// Returns a to a usable state if it has previously been deleted but not yet purged. - /// Is a no-op for already usable beatmaps. - /// - /// The store to restore beatmaps from. - /// The store to restore beatmap files from. - /// The beatmap to restore. - private void undelete(BeatmapSetInfo beatmapSet) - { - if (!beatmaps.Undelete(beatmapSet)) return; - - if (!beatmapSet.Protected) - files.Reference(beatmapSet.Files.Select(f => f.FileInfo).ToArray()); - } - /// /// Retrieve a instance for the provided /// diff --git a/osu.Game/Beatmaps/BeatmapStore.cs b/osu.Game/Beatmaps/BeatmapStore.cs index 67a2bbbd90..7a1dc763f0 100644 --- a/osu.Game/Beatmaps/BeatmapStore.cs +++ b/osu.Game/Beatmaps/BeatmapStore.cs @@ -31,9 +31,9 @@ namespace osu.Game.Beatmaps /// The beatmap to add. public void Add(BeatmapSetInfo beatmapSet) { - using (var db = ContextFactory.GetForWrite()) + using (var usage = ContextFactory.GetForWrite()) { - var context = db.Context; + var context = usage.Context; foreach (var beatmap in beatmapSet.Beatmaps.Where(b => b.Metadata != null)) { @@ -48,6 +48,7 @@ namespace osu.Game.Beatmaps } context.BeatmapSetInfo.Attach(beatmapSet); + BeatmapSetAdded?.Invoke(beatmapSet); } } @@ -73,11 +74,12 @@ namespace osu.Game.Beatmaps /// Whether the beatmap's was changed. public bool Delete(BeatmapSetInfo beatmapSet) { - using ( ContextFactory.GetForWrite()) + using (ContextFactory.GetForWrite()) { Refresh(ref beatmapSet, BeatmapSets); if (beatmapSet.DeletePending) return false; + beatmapSet.DeletePending = true; } @@ -92,11 +94,12 @@ namespace osu.Game.Beatmaps /// Whether the beatmap's was changed. public bool Undelete(BeatmapSetInfo beatmapSet) { - using ( ContextFactory.GetForWrite()) + using (ContextFactory.GetForWrite()) { Refresh(ref beatmapSet, BeatmapSets); if (!beatmapSet.DeletePending) return false; + beatmapSet.DeletePending = false; } @@ -116,6 +119,7 @@ namespace osu.Game.Beatmaps Refresh(ref beatmap, Beatmaps); if (beatmap.Hidden) return false; + beatmap.Hidden = true; BeatmapHidden?.Invoke(beatmap); @@ -136,6 +140,7 @@ namespace osu.Game.Beatmaps Refresh(ref beatmap, Beatmaps); if (!beatmap.Hidden) return false; + beatmap.Hidden = false; } @@ -155,7 +160,9 @@ namespace osu.Game.Beatmaps .Where(query) .Include(s => s.Beatmaps).ThenInclude(b => b.Metadata) .Include(s => s.Beatmaps).ThenInclude(b => b.BaseDifficulty) - .Include(s => s.Metadata); + .Include(s => s.Metadata).ToList(); + + if (!purgeable.Any()) return; // metadata is M-N so we can't rely on cascades context.BeatmapMetadata.RemoveRange(purgeable.Select(s => s.Metadata)); diff --git a/osu.Game/Database/DatabaseBackedStore.cs b/osu.Game/Database/DatabaseBackedStore.cs index da66167b14..0b2f34f6d1 100644 --- a/osu.Game/Database/DatabaseBackedStore.cs +++ b/osu.Game/Database/DatabaseBackedStore.cs @@ -34,10 +34,7 @@ namespace osu.Game.Database var id = obj.ID; var foundObject = lookupSource?.SingleOrDefault(t => t.ID == id) ?? context.Find(id); if (foundObject != null) - { obj = foundObject; - context.Entry(obj).Reload(); - } else context.Add(obj); } diff --git a/osu.Game/Database/DatabaseContextFactory.cs b/osu.Game/Database/DatabaseContextFactory.cs index c092ed377f..2291374e46 100644 --- a/osu.Game/Database/DatabaseContextFactory.cs +++ b/osu.Game/Database/DatabaseContextFactory.cs @@ -1,6 +1,7 @@ // Copyright (c) 2007-2018 ppy Pty Ltd . // Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE +using System.Diagnostics; using System.Threading; using osu.Framework.Platform; @@ -18,6 +19,7 @@ namespace osu.Game.Database private OsuDbContext writeContext; + private bool currentWriteDidWrite; private volatile int currentWriteUsages; public DatabaseContextFactory(GameHost host) @@ -38,24 +40,41 @@ namespace osu.Game.Database /// A usage containing a usable context. public DatabaseWriteUsage GetForWrite() { - lock (writeLock) - { - var usage = new DatabaseWriteUsage(writeContext ?? (writeContext = threadContexts.Value), usageCompleted); - Interlocked.Increment(ref currentWriteUsages); - return usage; - } + Monitor.Enter(writeLock); + + Trace.Assert(currentWriteUsages == 0, "Database writes in a bad state"); + Interlocked.Increment(ref currentWriteUsages); + + return new DatabaseWriteUsage(writeContext ?? (writeContext = threadContexts.Value), usageCompleted); } private void usageCompleted(DatabaseWriteUsage usage) { int usages = Interlocked.Decrement(ref currentWriteUsages); - if (usages == 0) + + try { - writeContext.Dispose(); + currentWriteDidWrite |= usage.PerformedWrite; + + if (usages > 0) return; + + + if (currentWriteDidWrite) + { + writeContext.Dispose(); + currentWriteDidWrite = false; + + // once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches. + recycleThreadContexts(); + } + + // always set to null (even when a write didn't occur) so we get the correct thread context on next write request. writeContext = null; - // once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches. - recycleThreadContexts(); + } + finally + { + Monitor.Exit(writeLock); } } diff --git a/osu.Game/Database/DatabaseWriteUsage.cs b/osu.Game/Database/DatabaseWriteUsage.cs index 0dc5a4cfe9..52dd0ee268 100644 --- a/osu.Game/Database/DatabaseWriteUsage.cs +++ b/osu.Game/Database/DatabaseWriteUsage.cs @@ -19,10 +19,28 @@ namespace osu.Game.Database usageCompleted = onCompleted; } + public bool PerformedWrite { get; private set; } + + private bool isDisposed; + + protected void Dispose(bool disposing) + { + if (isDisposed) return; + isDisposed = true; + + PerformedWrite |= Context.SaveChanges(transaction) > 0; + usageCompleted?.Invoke(this); + } + public void Dispose() { - Context.SaveChanges(transaction); - usageCompleted?.Invoke(this); + Dispose(true); + GC.SuppressFinalize(this); + } + + ~DatabaseWriteUsage() + { + Dispose(false); } } } diff --git a/osu.Game/Database/OsuDbContext.cs b/osu.Game/Database/OsuDbContext.cs index cf29ae4496..e83b30595e 100644 --- a/osu.Game/Database/OsuDbContext.cs +++ b/osu.Game/Database/OsuDbContext.cs @@ -111,7 +111,7 @@ namespace osu.Game.Database public int SaveChanges(IDbContextTransaction transaction = null) { var ret = base.SaveChanges(); - transaction?.Commit(); + if (ret > 0) transaction?.Commit(); return ret; } diff --git a/osu.Game/IO/FileStore.cs b/osu.Game/IO/FileStore.cs index 1bfe4db81a..9889088dc4 100644 --- a/osu.Game/IO/FileStore.cs +++ b/osu.Game/IO/FileStore.cs @@ -30,11 +30,9 @@ namespace osu.Game.IO { using (var usage = ContextFactory.GetForWrite()) { - var context = usage.Context; - string hash = data.ComputeSHA2Hash(); - var existing = context.FileInfo.FirstOrDefault(f => f.Hash == hash); + var existing = usage.Context.FileInfo.FirstOrDefault(f => f.Hash == hash); var info = existing ?? new FileInfo { Hash = hash }; @@ -60,6 +58,8 @@ namespace osu.Game.IO public void Reference(params FileInfo[] files) { + if (files.Length == 0) return; + using (var usage = ContextFactory.GetForWrite()) { var context = usage.Context; @@ -75,9 +75,12 @@ namespace osu.Game.IO public void Dereference(params FileInfo[] files) { + if (files.Length == 0) return; + using (var usage = ContextFactory.GetForWrite()) { var context = usage.Context; + foreach (var f in files.GroupBy(f => f.ID)) { var refetch = context.FileInfo.Find(f.Key); diff --git a/osu.Game/Input/KeyBindingStore.cs b/osu.Game/Input/KeyBindingStore.cs index 4aad684959..33cb0911a8 100644 --- a/osu.Game/Input/KeyBindingStore.cs +++ b/osu.Game/Input/KeyBindingStore.cs @@ -36,8 +36,6 @@ namespace osu.Game.Input { using (var usage = ContextFactory.GetForWrite()) { - var context = usage.Context; - // compare counts in database vs defaults foreach (var group in defaults.GroupBy(k => k.Action)) { @@ -49,7 +47,7 @@ namespace osu.Game.Input foreach (var insertable in group.Skip(count).Take(aimCount - count)) // insert any defaults which are missing. - context.DatabasedKeyBinding.Add(new DatabasedKeyBinding + usage.Context.DatabasedKeyBinding.Add(new DatabasedKeyBinding { KeyCombination = insertable.KeyCombination, Action = insertable.Action, @@ -75,6 +73,10 @@ namespace osu.Game.Input { var dbKeyBinding = (DatabasedKeyBinding)keyBinding; Refresh(ref dbKeyBinding); + + if (dbKeyBinding.KeyCombination.Equals(keyBinding.KeyCombination)) + return; + dbKeyBinding.KeyCombination = keyBinding.KeyCombination; }