Only write when writes occur

Also add finaliser logic for safety. Also better threading. Also more cleanup.
This commit is contained in:
Dean Herbert
2018-02-12 19:57:21 +09:00
parent edc3638175
commit 8b37fde15b
8 changed files with 86 additions and 50 deletions

View File

@ -172,7 +172,7 @@ namespace osu.Game.Beatmaps
/// <param name="archive">The beatmap to be imported.</param> /// <param name="archive">The beatmap to be imported.</param>
public BeatmapSetInfo Import(ArchiveReader archive) public BeatmapSetInfo Import(ArchiveReader archive)
{ {
using ( contextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes. using (contextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes.
{ {
// create a new set info (don't yet add to database) // create a new set info (don't yet add to database)
var beatmapSet = createBeatmapSetInfo(archive); var beatmapSet = createBeatmapSetInfo(archive);
@ -181,7 +181,7 @@ namespace osu.Game.Beatmaps
var existingHashMatch = beatmaps.BeatmapSets.FirstOrDefault(b => b.Hash == beatmapSet.Hash); var existingHashMatch = beatmaps.BeatmapSets.FirstOrDefault(b => b.Hash == beatmapSet.Hash);
if (existingHashMatch != null) if (existingHashMatch != null)
{ {
undelete(existingHashMatch); Undelete(existingHashMatch);
return existingHashMatch; return existingHashMatch;
} }
@ -315,9 +315,9 @@ namespace osu.Game.Beatmaps
/// <param name="beatmapSet">The beatmap set to delete.</param> /// <param name="beatmapSet">The beatmap set to delete.</param>
public void Delete(BeatmapSetInfo beatmapSet) public void Delete(BeatmapSetInfo beatmapSet)
{ {
using (var db = contextFactory.GetForWrite()) using (var usage = contextFactory.GetForWrite())
{ {
var context = db.Context; var context = usage.Context;
context.ChangeTracker.AutoDetectChangesEnabled = false; context.ChangeTracker.AutoDetectChangesEnabled = false;
@ -378,11 +378,16 @@ namespace osu.Game.Beatmaps
if (beatmapSet.Protected) if (beatmapSet.Protected)
return; return;
using (var db = contextFactory.GetForWrite()) using (var usage = contextFactory.GetForWrite())
{ {
db.Context.ChangeTracker.AutoDetectChangesEnabled = false; usage.Context.ChangeTracker.AutoDetectChangesEnabled = false;
undelete(beatmapSet);
db.Context.ChangeTracker.AutoDetectChangesEnabled = true; if (!beatmaps.Undelete(beatmapSet)) return;
if (!beatmapSet.Protected)
files.Reference(beatmapSet.Files.Select(f => f.FileInfo).ToArray());
usage.Context.ChangeTracker.AutoDetectChangesEnabled = true;
} }
} }
@ -398,21 +403,6 @@ namespace osu.Game.Beatmaps
/// <param name="beatmap">The beatmap difficulty to restore.</param> /// <param name="beatmap">The beatmap difficulty to restore.</param>
public void Restore(BeatmapInfo beatmap) => beatmaps.Restore(beatmap); public void Restore(BeatmapInfo beatmap) => beatmaps.Restore(beatmap);
/// <summary>
/// Returns a <see cref="BeatmapSetInfo"/> to a usable state if it has previously been deleted but not yet purged.
/// Is a no-op for already usable beatmaps.
/// </summary>
/// <param name="beatmaps">The store to restore beatmaps from.</param>
/// <param name="files">The store to restore beatmap files from.</param>
/// <param name="beatmapSet">The beatmap to restore.</param>
private void undelete(BeatmapSetInfo beatmapSet)
{
if (!beatmaps.Undelete(beatmapSet)) return;
if (!beatmapSet.Protected)
files.Reference(beatmapSet.Files.Select(f => f.FileInfo).ToArray());
}
/// <summary> /// <summary>
/// Retrieve a <see cref="WorkingBeatmap"/> instance for the provided <see cref="BeatmapInfo"/> /// Retrieve a <see cref="WorkingBeatmap"/> instance for the provided <see cref="BeatmapInfo"/>
/// </summary> /// </summary>

View File

@ -31,9 +31,9 @@ namespace osu.Game.Beatmaps
/// <param name="beatmapSet">The beatmap to add.</param> /// <param name="beatmapSet">The beatmap to add.</param>
public void Add(BeatmapSetInfo beatmapSet) public void Add(BeatmapSetInfo beatmapSet)
{ {
using (var db = ContextFactory.GetForWrite()) using (var usage = ContextFactory.GetForWrite())
{ {
var context = db.Context; var context = usage.Context;
foreach (var beatmap in beatmapSet.Beatmaps.Where(b => b.Metadata != null)) foreach (var beatmap in beatmapSet.Beatmaps.Where(b => b.Metadata != null))
{ {
@ -48,6 +48,7 @@ namespace osu.Game.Beatmaps
} }
context.BeatmapSetInfo.Attach(beatmapSet); context.BeatmapSetInfo.Attach(beatmapSet);
BeatmapSetAdded?.Invoke(beatmapSet); BeatmapSetAdded?.Invoke(beatmapSet);
} }
} }
@ -73,11 +74,12 @@ namespace osu.Game.Beatmaps
/// <returns>Whether the beatmap's <see cref="BeatmapSetInfo.DeletePending"/> was changed.</returns> /// <returns>Whether the beatmap's <see cref="BeatmapSetInfo.DeletePending"/> was changed.</returns>
public bool Delete(BeatmapSetInfo beatmapSet) public bool Delete(BeatmapSetInfo beatmapSet)
{ {
using ( ContextFactory.GetForWrite()) using (ContextFactory.GetForWrite())
{ {
Refresh(ref beatmapSet, BeatmapSets); Refresh(ref beatmapSet, BeatmapSets);
if (beatmapSet.DeletePending) return false; if (beatmapSet.DeletePending) return false;
beatmapSet.DeletePending = true; beatmapSet.DeletePending = true;
} }
@ -92,11 +94,12 @@ namespace osu.Game.Beatmaps
/// <returns>Whether the beatmap's <see cref="BeatmapSetInfo.DeletePending"/> was changed.</returns> /// <returns>Whether the beatmap's <see cref="BeatmapSetInfo.DeletePending"/> was changed.</returns>
public bool Undelete(BeatmapSetInfo beatmapSet) public bool Undelete(BeatmapSetInfo beatmapSet)
{ {
using ( ContextFactory.GetForWrite()) using (ContextFactory.GetForWrite())
{ {
Refresh(ref beatmapSet, BeatmapSets); Refresh(ref beatmapSet, BeatmapSets);
if (!beatmapSet.DeletePending) return false; if (!beatmapSet.DeletePending) return false;
beatmapSet.DeletePending = false; beatmapSet.DeletePending = false;
} }
@ -116,6 +119,7 @@ namespace osu.Game.Beatmaps
Refresh(ref beatmap, Beatmaps); Refresh(ref beatmap, Beatmaps);
if (beatmap.Hidden) return false; if (beatmap.Hidden) return false;
beatmap.Hidden = true; beatmap.Hidden = true;
BeatmapHidden?.Invoke(beatmap); BeatmapHidden?.Invoke(beatmap);
@ -136,6 +140,7 @@ namespace osu.Game.Beatmaps
Refresh(ref beatmap, Beatmaps); Refresh(ref beatmap, Beatmaps);
if (!beatmap.Hidden) return false; if (!beatmap.Hidden) return false;
beatmap.Hidden = false; beatmap.Hidden = false;
} }
@ -155,7 +160,9 @@ namespace osu.Game.Beatmaps
.Where(query) .Where(query)
.Include(s => s.Beatmaps).ThenInclude(b => b.Metadata) .Include(s => s.Beatmaps).ThenInclude(b => b.Metadata)
.Include(s => s.Beatmaps).ThenInclude(b => b.BaseDifficulty) .Include(s => s.Beatmaps).ThenInclude(b => b.BaseDifficulty)
.Include(s => s.Metadata); .Include(s => s.Metadata).ToList();
if (!purgeable.Any()) return;
// metadata is M-N so we can't rely on cascades // metadata is M-N so we can't rely on cascades
context.BeatmapMetadata.RemoveRange(purgeable.Select(s => s.Metadata)); context.BeatmapMetadata.RemoveRange(purgeable.Select(s => s.Metadata));

View File

@ -34,10 +34,7 @@ namespace osu.Game.Database
var id = obj.ID; var id = obj.ID;
var foundObject = lookupSource?.SingleOrDefault(t => t.ID == id) ?? context.Find<T>(id); var foundObject = lookupSource?.SingleOrDefault(t => t.ID == id) ?? context.Find<T>(id);
if (foundObject != null) if (foundObject != null)
{
obj = foundObject; obj = foundObject;
context.Entry(obj).Reload();
}
else else
context.Add(obj); context.Add(obj);
} }

View File

@ -1,6 +1,7 @@
// Copyright (c) 2007-2018 ppy Pty Ltd <contact@ppy.sh>. // Copyright (c) 2007-2018 ppy Pty Ltd <contact@ppy.sh>.
// Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE // Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE
using System.Diagnostics;
using System.Threading; using System.Threading;
using osu.Framework.Platform; using osu.Framework.Platform;
@ -18,6 +19,7 @@ namespace osu.Game.Database
private OsuDbContext writeContext; private OsuDbContext writeContext;
private bool currentWriteDidWrite;
private volatile int currentWriteUsages; private volatile int currentWriteUsages;
public DatabaseContextFactory(GameHost host) public DatabaseContextFactory(GameHost host)
@ -38,24 +40,41 @@ namespace osu.Game.Database
/// <returns>A usage containing a usable context.</returns> /// <returns>A usage containing a usable context.</returns>
public DatabaseWriteUsage GetForWrite() public DatabaseWriteUsage GetForWrite()
{ {
lock (writeLock) Monitor.Enter(writeLock);
{
var usage = new DatabaseWriteUsage(writeContext ?? (writeContext = threadContexts.Value), usageCompleted); Trace.Assert(currentWriteUsages == 0, "Database writes in a bad state");
Interlocked.Increment(ref currentWriteUsages); Interlocked.Increment(ref currentWriteUsages);
return usage;
} return new DatabaseWriteUsage(writeContext ?? (writeContext = threadContexts.Value), usageCompleted);
} }
private void usageCompleted(DatabaseWriteUsage usage) private void usageCompleted(DatabaseWriteUsage usage)
{ {
int usages = Interlocked.Decrement(ref currentWriteUsages); int usages = Interlocked.Decrement(ref currentWriteUsages);
if (usages == 0)
try
{ {
writeContext.Dispose(); currentWriteDidWrite |= usage.PerformedWrite;
if (usages > 0) return;
if (currentWriteDidWrite)
{
writeContext.Dispose();
currentWriteDidWrite = false;
// once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches.
recycleThreadContexts();
}
// always set to null (even when a write didn't occur) so we get the correct thread context on next write request.
writeContext = null; writeContext = null;
// once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches. }
recycleThreadContexts(); finally
{
Monitor.Exit(writeLock);
} }
} }

View File

@ -19,10 +19,28 @@ namespace osu.Game.Database
usageCompleted = onCompleted; usageCompleted = onCompleted;
} }
public bool PerformedWrite { get; private set; }
private bool isDisposed;
protected void Dispose(bool disposing)
{
if (isDisposed) return;
isDisposed = true;
PerformedWrite |= Context.SaveChanges(transaction) > 0;
usageCompleted?.Invoke(this);
}
public void Dispose() public void Dispose()
{ {
Context.SaveChanges(transaction); Dispose(true);
usageCompleted?.Invoke(this); GC.SuppressFinalize(this);
}
~DatabaseWriteUsage()
{
Dispose(false);
} }
} }
} }

View File

@ -111,7 +111,7 @@ namespace osu.Game.Database
public int SaveChanges(IDbContextTransaction transaction = null) public int SaveChanges(IDbContextTransaction transaction = null)
{ {
var ret = base.SaveChanges(); var ret = base.SaveChanges();
transaction?.Commit(); if (ret > 0) transaction?.Commit();
return ret; return ret;
} }

View File

@ -30,11 +30,9 @@ namespace osu.Game.IO
{ {
using (var usage = ContextFactory.GetForWrite()) using (var usage = ContextFactory.GetForWrite())
{ {
var context = usage.Context;
string hash = data.ComputeSHA2Hash(); string hash = data.ComputeSHA2Hash();
var existing = context.FileInfo.FirstOrDefault(f => f.Hash == hash); var existing = usage.Context.FileInfo.FirstOrDefault(f => f.Hash == hash);
var info = existing ?? new FileInfo { Hash = hash }; var info = existing ?? new FileInfo { Hash = hash };
@ -60,6 +58,8 @@ namespace osu.Game.IO
public void Reference(params FileInfo[] files) public void Reference(params FileInfo[] files)
{ {
if (files.Length == 0) return;
using (var usage = ContextFactory.GetForWrite()) using (var usage = ContextFactory.GetForWrite())
{ {
var context = usage.Context; var context = usage.Context;
@ -75,9 +75,12 @@ namespace osu.Game.IO
public void Dereference(params FileInfo[] files) public void Dereference(params FileInfo[] files)
{ {
if (files.Length == 0) return;
using (var usage = ContextFactory.GetForWrite()) using (var usage = ContextFactory.GetForWrite())
{ {
var context = usage.Context; var context = usage.Context;
foreach (var f in files.GroupBy(f => f.ID)) foreach (var f in files.GroupBy(f => f.ID))
{ {
var refetch = context.FileInfo.Find(f.Key); var refetch = context.FileInfo.Find(f.Key);

View File

@ -36,8 +36,6 @@ namespace osu.Game.Input
{ {
using (var usage = ContextFactory.GetForWrite()) using (var usage = ContextFactory.GetForWrite())
{ {
var context = usage.Context;
// compare counts in database vs defaults // compare counts in database vs defaults
foreach (var group in defaults.GroupBy(k => k.Action)) foreach (var group in defaults.GroupBy(k => k.Action))
{ {
@ -49,7 +47,7 @@ namespace osu.Game.Input
foreach (var insertable in group.Skip(count).Take(aimCount - count)) foreach (var insertable in group.Skip(count).Take(aimCount - count))
// insert any defaults which are missing. // insert any defaults which are missing.
context.DatabasedKeyBinding.Add(new DatabasedKeyBinding usage.Context.DatabasedKeyBinding.Add(new DatabasedKeyBinding
{ {
KeyCombination = insertable.KeyCombination, KeyCombination = insertable.KeyCombination,
Action = insertable.Action, Action = insertable.Action,
@ -75,6 +73,10 @@ namespace osu.Game.Input
{ {
var dbKeyBinding = (DatabasedKeyBinding)keyBinding; var dbKeyBinding = (DatabasedKeyBinding)keyBinding;
Refresh(ref dbKeyBinding); Refresh(ref dbKeyBinding);
if (dbKeyBinding.KeyCombination.Equals(keyBinding.KeyCombination))
return;
dbKeyBinding.KeyCombination = keyBinding.KeyCombination; dbKeyBinding.KeyCombination = keyBinding.KeyCombination;
} }