diff --git a/ChocolArm64/Events/InvalidAccessEventArgs.cs b/ChocolArm64/Events/InvalidAccessEventArgs.cs index a8046d73..9c349755 100644 --- a/ChocolArm64/Events/InvalidAccessEventArgs.cs +++ b/ChocolArm64/Events/InvalidAccessEventArgs.cs @@ -2,11 +2,11 @@ using System; namespace ChocolArm64.Events { - public class InvalidAccessEventArgs : EventArgs + public class MemoryAccessEventArgs : EventArgs { public long Position { get; private set; } - public InvalidAccessEventArgs(long position) + public MemoryAccessEventArgs(long position) { Position = position; } diff --git a/ChocolArm64/Memory/MemoryManager.cs b/ChocolArm64/Memory/MemoryManager.cs index eacb5336..1f212568 100644 --- a/ChocolArm64/Memory/MemoryManager.cs +++ b/ChocolArm64/Memory/MemoryManager.cs @@ -17,18 +17,18 @@ namespace ChocolArm64.Memory { private const int PtLvl0Bits = 13; private const int PtLvl1Bits = 14; - private const int PtPageBits = 12; + public const int PageBits = 12; private const int PtLvl0Size = 1 << PtLvl0Bits; private const int PtLvl1Size = 1 << PtLvl1Bits; - public const int PageSize = 1 << PtPageBits; + public const int PageSize = 1 << PageBits; private const int PtLvl0Mask = PtLvl0Size - 1; private const int PtLvl1Mask = PtLvl1Size - 1; public const int PageMask = PageSize - 1; - private const int PtLvl0Bit = PtPageBits + PtLvl1Bits; - private const int PtLvl1Bit = PtPageBits; + private const int PtLvl0Bit = PageBits + PtLvl1Bits; + private const int PtLvl1Bit = PageBits; private const long ErgMask = (4 << CpuThreadState.ErgSizeLog2) - 1; @@ -53,7 +53,9 @@ namespace ChocolArm64.Memory private byte*** _pageTable; - public event EventHandler InvalidAccess; + public event EventHandler InvalidAccess; + + public event EventHandler ObservedAccess; public MemoryManager(IntPtr ram) { @@ -632,7 +634,7 @@ namespace ChocolArm64.Memory return false; } - return _pageTable[l0][l1] != null || _observedPages.ContainsKey(position >> PtPageBits); + return _pageTable[l0][l1] != null || _observedPages.ContainsKey(position >> PageBits); } public long GetPhysicalAddress(long virtualAddress) @@ -678,14 +680,14 @@ Unmapped: private byte* HandleNullPte(long position) { - long key = position >> PtPageBits; + long key = position >> PageBits; if (_observedPages.TryGetValue(key, out IntPtr ptr)) { return (byte*)ptr + (position & PageMask); } - InvalidAccess?.Invoke(this, new InvalidAccessEventArgs(position)); + InvalidAccess?.Invoke(this, new MemoryAccessEventArgs(position)); throw new VmmPageFaultException(position); } @@ -726,16 +728,20 @@ Unmapped: private byte* HandleNullPteWrite(long position) { - long key = position >> PtPageBits; + long key = position >> PageBits; + + MemoryAccessEventArgs e = new MemoryAccessEventArgs(position); if (_observedPages.TryGetValue(key, out IntPtr ptr)) { SetPtEntry(position, (byte*)ptr); + ObservedAccess?.Invoke(this, e); + return (byte*)ptr + (position & PageMask); } - InvalidAccess?.Invoke(this, new InvalidAccessEventArgs(position)); + InvalidAccess?.Invoke(this, e); throw new VmmPageFaultException(position); } @@ -784,53 +790,20 @@ Unmapped: _pageTable[l0][l1] = ptr; } - public (bool[], int) IsRegionModified(long position, long size) + public void StartObservingRegion(long position, long size) { long endPosition = (position + size + PageMask) & ~PageMask; position &= ~PageMask; - size = endPosition - position; - - bool[] modified = new bool[size >> PtPageBits]; - - int count = 0; - - lock (_observedPages) + while ((ulong)position < (ulong)endPosition) { - for (int page = 0; page < modified.Length; page++) - { - byte* ptr = Translate(position); + _observedPages[position >> PageBits] = (IntPtr)Translate(position); - if (_observedPages.TryAdd(position >> PtPageBits, (IntPtr)ptr)) - { - modified[page] = true; + SetPtEntry(position, null); - count++; - } - else - { - long l0 = (position >> PtLvl0Bit) & PtLvl0Mask; - long l1 = (position >> PtLvl1Bit) & PtLvl1Mask; - - byte** lvl1 = _pageTable[l0]; - - if (lvl1 != null) - { - if (modified[page] = lvl1[l1] != null) - { - count++; - } - } - } - - SetPtEntry(position, null); - - position += PageSize; - } + position += PageSize; } - - return (modified, count); } public void StopObservingRegion(long position, long size) @@ -841,7 +814,7 @@ Unmapped: { lock (_observedPages) { - if (_observedPages.TryRemove(position >> PtPageBits, out IntPtr ptr)) + if (_observedPages.TryRemove(position >> PageBits, out IntPtr ptr)) { SetPtEntry(position, (byte*)ptr); } @@ -891,7 +864,7 @@ Unmapped: public bool IsValidPosition(long position) { - return position >> (PtLvl0Bits + PtLvl1Bits + PtPageBits) == 0; + return position >> (PtLvl0Bits + PtLvl1Bits + PageBits) == 0; } public void Dispose() diff --git a/Ryujinx.Graphics/Memory/NvGpuVmm.cs b/Ryujinx.Graphics/Memory/NvGpuVmm.cs index cfd1aaeb..7fdef473 100644 --- a/Ryujinx.Graphics/Memory/NvGpuVmm.cs +++ b/Ryujinx.Graphics/Memory/NvGpuVmm.cs @@ -36,7 +36,7 @@ namespace Ryujinx.Graphics.Memory { this.Memory = Memory; - Cache = new NvGpuVmmCache(); + Cache = new NvGpuVmmCache(Memory); PageTable = new long[PTLvl0Size][]; } @@ -262,7 +262,7 @@ namespace Ryujinx.Graphics.Memory public bool IsRegionModified(long PA, long Size, NvGpuBufferType BufferType) { - return Cache.IsRegionModified(Memory, BufferType, PA, Size); + return Cache.IsRegionModified(PA, Size, BufferType); } public bool TryGetHostAddress(long Position, long Size, out IntPtr Ptr) diff --git a/Ryujinx.Graphics/Memory/NvGpuVmmCache.cs b/Ryujinx.Graphics/Memory/NvGpuVmmCache.cs index dd6d37c9..2f50463d 100644 --- a/Ryujinx.Graphics/Memory/NvGpuVmmCache.cs +++ b/Ryujinx.Graphics/Memory/NvGpuVmmCache.cs @@ -1,130 +1,83 @@ +using ChocolArm64.Events; using ChocolArm64.Memory; -using System; +using System.Collections.Concurrent; namespace Ryujinx.Graphics.Memory { class NvGpuVmmCache { - private struct CachedResource + private const int PageBits = MemoryManager.PageBits; + + private const long PageSize = MemoryManager.PageSize; + private const long PageMask = MemoryManager.PageMask; + + private ConcurrentDictionary[] CachedPages; + + private MemoryManager _memory; + + public NvGpuVmmCache(MemoryManager memory) { - public long Key; - public int Mask; + _memory = memory; - public CachedResource(long Key, int Mask) - { - this.Key = Key; - this.Mask = Mask; - } + _memory.ObservedAccess += MemoryAccessHandler; - public override int GetHashCode() - { - return (int)(Key * 23 + Mask); - } - - public override bool Equals(object obj) - { - return obj is CachedResource Cached && Equals(Cached); - } - - public bool Equals(CachedResource other) - { - return Key == other.Key && Mask == other.Mask; - } + CachedPages = new ConcurrentDictionary[1 << 20]; } - private ValueRangeSet CachedRanges; - - public NvGpuVmmCache() + private void MemoryAccessHandler(object sender, MemoryAccessEventArgs e) { - CachedRanges = new ValueRangeSet(); + long pa = _memory.GetPhysicalAddress(e.Position); + + CachedPages[pa >> PageBits]?.Clear(); } - public bool IsRegionModified(MemoryManager Memory, NvGpuBufferType BufferType, long Start, long Size) + public bool IsRegionModified(long position, long size, NvGpuBufferType bufferType) { - (bool[] Modified, long ModifiedCount) = Memory.IsRegionModified(Start, Size); + long pa = _memory.GetPhysicalAddress(position); - //Remove all modified ranges. - int Index = 0; + long addr = pa; - long Position = Start & ~NvGpuVmm.PageMask; + long endAddr = (addr + size + PageMask) & ~PageMask; - while (ModifiedCount > 0) + int newBuffMask = 1 << (int)bufferType; + + _memory.StartObservingRegion(position, size); + + long cachedPagesCount = 0; + + while (addr < endAddr) { - if (Modified[Index++]) - { - CachedRanges.Remove(new ValueRange(Position, Position + NvGpuVmm.PageSize)); + long page = addr >> PageBits; - ModifiedCount--; + ConcurrentDictionary dictionary = CachedPages[page]; + + if (dictionary == null) + { + dictionary = new ConcurrentDictionary(); + + CachedPages[page] = dictionary; } - Position += NvGpuVmm.PageSize; + if (dictionary.TryGetValue(pa, out int currBuffMask)) + { + if ((currBuffMask & newBuffMask) != 0) + { + cachedPagesCount++; + } + else + { + dictionary[pa] |= newBuffMask; + } + } + else + { + dictionary[pa] = newBuffMask; + } + + addr += PageSize; } - //Mask has the bit set for the current resource type. - //If the region is not yet present on the list, then a new ValueRange - //is directly added with the current resource type as the only bit set. - //Otherwise, it just sets the bit for this new resource type on the current mask. - //The physical address of the resource is used as key, those keys are used to keep - //track of resources that are already on the cache. A resource may be inside another - //resource, and in this case we should return true if the "sub-resource" was not - //yet cached. - int Mask = 1 << (int)BufferType; - - CachedResource NewCachedValue = new CachedResource(Start, Mask); - - ValueRange NewCached = new ValueRange(Start, Start + Size); - - ValueRange[] Ranges = CachedRanges.GetAllIntersections(NewCached); - - bool IsKeyCached = Ranges.Length > 0 && Ranges[0].Value.Key == Start; - - long LastEnd = NewCached.Start; - - long Coverage = 0; - - for (Index = 0; Index < Ranges.Length; Index++) - { - ValueRange Current = Ranges[Index]; - - CachedResource Cached = Current.Value; - - long RgStart = Math.Max(Current.Start, NewCached.Start); - long RgEnd = Math.Min(Current.End, NewCached.End); - - if ((Cached.Mask & Mask) != 0) - { - Coverage += RgEnd - RgStart; - } - - //Highest key value has priority, this prevents larger resources - //for completely invalidating smaller ones on the cache. For example, - //consider that a resource in the range [100, 200) was added, and then - //another one in the range [50, 200). We prevent the new resource from - //completely replacing the old one by spliting it like this: - //New resource key is added at [50, 100), old key is still present at [100, 200). - if (Cached.Key < Start) - { - Cached.Key = Start; - } - - Cached.Mask |= Mask; - - CachedRanges.Add(new ValueRange(RgStart, RgEnd, Cached)); - - if (RgStart > LastEnd) - { - CachedRanges.Add(new ValueRange(LastEnd, RgStart, NewCachedValue)); - } - - LastEnd = RgEnd; - } - - if (LastEnd < NewCached.End) - { - CachedRanges.Add(new ValueRange(LastEnd, NewCached.End, NewCachedValue)); - } - - return !IsKeyCached || Coverage != Size; + return cachedPagesCount != (endAddr - pa + PageMask) >> PageBits; } } } \ No newline at end of file diff --git a/Ryujinx.Graphics/ValueRange.cs b/Ryujinx.Graphics/ValueRange.cs deleted file mode 100644 index 6298bd8e..00000000 --- a/Ryujinx.Graphics/ValueRange.cs +++ /dev/null @@ -1,17 +0,0 @@ -namespace Ryujinx.Graphics -{ - struct ValueRange - { - public long Start { get; private set; } - public long End { get; private set; } - - public T Value { get; set; } - - public ValueRange(long Start, long End, T Value = default(T)) - { - this.Start = Start; - this.End = End; - this.Value = Value; - } - } -} \ No newline at end of file diff --git a/Ryujinx.Graphics/ValueRangeSet.cs b/Ryujinx.Graphics/ValueRangeSet.cs deleted file mode 100644 index 42125bce..00000000 --- a/Ryujinx.Graphics/ValueRangeSet.cs +++ /dev/null @@ -1,234 +0,0 @@ -using System.Collections.Generic; - -namespace Ryujinx.Graphics -{ - class ValueRangeSet - { - private List> Ranges; - - public ValueRangeSet() - { - Ranges = new List>(); - } - - public void Add(ValueRange Range) - { - if (Range.End <= Range.Start) - { - //Empty or invalid range, do nothing. - return; - } - - int First = BinarySearchFirstIntersection(Range); - - if (First == -1) - { - //No intersections case. - //Find first greater than range (after the current one). - //If found, add before, otherwise add to the end of the list. - int GtIndex = BinarySearchGt(Range); - - if (GtIndex != -1) - { - Ranges.Insert(GtIndex, Range); - } - else - { - Ranges.Add(Range); - } - - return; - } - - (int Start, int End) = GetAllIntersectionRanges(Range, First); - - ValueRange Prev = Ranges[Start]; - ValueRange Next = Ranges[End]; - - Ranges.RemoveRange(Start, (End - Start) + 1); - - InsertNextNeighbour(Start, Range, Next); - - int NewIndex = Start; - - Ranges.Insert(Start, Range); - - InsertPrevNeighbour(Start, Range, Prev); - - //Try merging neighbours if the value is equal. - if (NewIndex > 0) - { - Prev = Ranges[NewIndex - 1]; - - if (Prev.End == Range.Start && CompareValues(Prev, Range)) - { - Ranges.RemoveAt(--NewIndex); - - Ranges[NewIndex] = new ValueRange(Prev.Start, Range.End, Range.Value); - } - } - - if (NewIndex < Ranges.Count - 1) - { - Next = Ranges[NewIndex + 1]; - - if (Next.Start == Range.End && CompareValues(Next, Range)) - { - Ranges.RemoveAt(NewIndex + 1); - - Ranges[NewIndex] = new ValueRange(Ranges[NewIndex].Start, Next.End, Range.Value); - } - } - } - - private bool CompareValues(ValueRange LHS, ValueRange RHS) - { - return LHS.Value?.Equals(RHS.Value) ?? RHS.Value == null; - } - - public void Remove(ValueRange Range) - { - int First = BinarySearchFirstIntersection(Range); - - if (First == -1) - { - //Nothing to remove. - return; - } - - (int Start, int End) = GetAllIntersectionRanges(Range, First); - - ValueRange Prev = Ranges[Start]; - ValueRange Next = Ranges[End]; - - Ranges.RemoveRange(Start, (End - Start) + 1); - - InsertNextNeighbour(Start, Range, Next); - InsertPrevNeighbour(Start, Range, Prev); - } - - private void InsertNextNeighbour(int Index, ValueRange Range, ValueRange Next) - { - //Split last intersection (ordered by Start) if necessary. - if (Range.End < Next.End) - { - InsertNewRange(Index, Range.End, Next.End, Next.Value); - } - } - - private void InsertPrevNeighbour(int Index, ValueRange Range, ValueRange Prev) - { - //Split first intersection (ordered by Start) if necessary. - if (Range.Start > Prev.Start) - { - InsertNewRange(Index, Prev.Start, Range.Start, Prev.Value); - } - } - - private void InsertNewRange(int Index, long Start, long End, T Value) - { - Ranges.Insert(Index, new ValueRange(Start, End, Value)); - } - - public ValueRange[] GetAllIntersections(ValueRange Range) - { - int First = BinarySearchFirstIntersection(Range); - - if (First == -1) - { - return new ValueRange[0]; - } - - (int Start, int End) = GetAllIntersectionRanges(Range, First); - - return Ranges.GetRange(Start, (End - Start) + 1).ToArray(); - } - - private (int Start, int End) GetAllIntersectionRanges(ValueRange Range, int BaseIndex) - { - int Start = BaseIndex; - int End = BaseIndex; - - while (Start > 0 && Intersects(Range, Ranges[Start - 1])) - { - Start--; - } - - while (End < Ranges.Count - 1 && Intersects(Range, Ranges[End + 1])) - { - End++; - } - - return (Start, End); - } - - private int BinarySearchFirstIntersection(ValueRange Range) - { - int Left = 0; - int Right = Ranges.Count - 1; - - while (Left <= Right) - { - int Size = Right - Left; - - int Middle = Left + (Size >> 1); - - ValueRange Current = Ranges[Middle]; - - if (Intersects(Range, Current)) - { - return Middle; - } - - if (Range.Start < Current.Start) - { - Right = Middle - 1; - } - else - { - Left = Middle + 1; - } - } - - return -1; - } - - private int BinarySearchGt(ValueRange Range) - { - int GtIndex = -1; - - int Left = 0; - int Right = Ranges.Count - 1; - - while (Left <= Right) - { - int Size = Right - Left; - - int Middle = Left + (Size >> 1); - - ValueRange Current = Ranges[Middle]; - - if (Range.Start < Current.Start) - { - Right = Middle - 1; - - if (GtIndex == -1 || Current.Start < Ranges[GtIndex].Start) - { - GtIndex = Middle; - } - } - else - { - Left = Middle + 1; - } - } - - return GtIndex; - } - - private bool Intersects(ValueRange LHS, ValueRange RHS) - { - return LHS.Start < RHS.End && RHS.Start < LHS.End; - } - } -} \ No newline at end of file diff --git a/Ryujinx.HLE/HOS/Kernel/KProcess.cs b/Ryujinx.HLE/HOS/Kernel/KProcess.cs index c5cfd964..6d91f41c 100644 --- a/Ryujinx.HLE/HOS/Kernel/KProcess.cs +++ b/Ryujinx.HLE/HOS/Kernel/KProcess.cs @@ -995,7 +995,7 @@ namespace Ryujinx.HLE.HOS.Kernel } } - private void InvalidAccessHandler(object sender, InvalidAccessEventArgs e) + private void InvalidAccessHandler(object sender, MemoryAccessEventArgs e) { PrintCurrentThreadStackTrace(); }