diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs index 42a756233..e87faef78 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs @@ -143,24 +143,36 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl context.EnterScope(); - if (context.Definitions.Stage == ShaderStage.Fragment) - { - // TODO: check if it's needed - context.AppendLine("float4 position [[position]];"); - } - foreach (var ioDefinition in inputs.OrderBy(x => x.Location)) { - string type = GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: false)); - string name = $"{DefaultNames.IAttributePrefix}{ioDefinition.Location}"; - string suffix = context.Definitions.Stage switch + string type = ioDefinition.IoVariable switch { - ShaderStage.Vertex => $" [[attribute({ioDefinition.Location})]]", - ShaderStage.Fragment => $" [[user(loc{ioDefinition.Location})]]", + IoVariable.Position => "float4", + IoVariable.GlobalId => "uint3", + IoVariable.VertexId => "uint", + IoVariable.VertexIndex => "uint", + _ => GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: false)) + }; + string name = ioDefinition.IoVariable switch + { + IoVariable.Position => "position", + IoVariable.GlobalId => "global_id", + IoVariable.VertexId => "vertex_id", + IoVariable.VertexIndex => "vertex_index", + _ => $"{DefaultNames.IAttributePrefix}{ioDefinition.Location}" + }; + string suffix = ioDefinition.IoVariable switch + { + IoVariable.Position => "[[position]]", + IoVariable.GlobalId => "[[thread_position_in_grid]]", + IoVariable.VertexId => "[[vertex_id]]", + // TODO: Avoid potential redeclaration + IoVariable.VertexIndex => "[[vertex_id]]", + IoVariable.UserDefined => context.Definitions.Stage == ShaderStage.Fragment ? $"[[user(loc{ioDefinition.Location})]]" : $"[[attribute({ioDefinition.Location})]]", _ => "" }; - context.AppendLine($"{type} {name}{suffix};"); + context.AppendLine($"{type} {name} {suffix};"); } context.LeaveScope(";"); @@ -212,14 +224,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl }; string suffix = ioDefinition.IoVariable switch { - IoVariable.Position => " [[position]]", - IoVariable.PointSize => " [[point_size]]", - IoVariable.UserDefined => $" [[user(loc{ioDefinition.Location})]]", - IoVariable.FragmentOutputColor => $" [[color({ioDefinition.Location})]]", + IoVariable.Position => "[[position]]", + IoVariable.PointSize => "[[point_size]]", + IoVariable.UserDefined => $"[[user(loc{ioDefinition.Location})]]", + IoVariable.FragmentOutputColor => $"[[color({ioDefinition.Location})]]", _ => "" }; - context.AppendLine($"{type} {name}{suffix};"); + context.AppendLine($"{type} {name} {suffix};"); } context.LeaveScope(";"); diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs index c836d9832..4eb4f2581 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs @@ -1,3 +1,4 @@ +using Ryujinx.Common.Logging; using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.Translation; using System.Globalization; @@ -14,7 +15,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions bool isOutput, bool isPerPatch) { - return ioVariable switch + var returnValue = ioVariable switch { IoVariable.BaseInstance => ("base_instance", AggregateType.S32), IoVariable.BaseVertex => ("base_vertex", AggregateType.S32), @@ -29,10 +30,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32), IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch), IoVariable.VertexId => ("vertex_id", AggregateType.S32), + IoVariable.GlobalId => ("global_id", AggregateType.Vector3 | AggregateType.U32), + // gl_VertexIndex does not have a direct equivalent in MSL + IoVariable.VertexIndex => ("vertex_index", AggregateType.U32), IoVariable.ViewportIndex => ("viewport_array_index", AggregateType.S32), - IoVariable.FragmentCoord => ("in.position", AggregateType.Vector4 | AggregateType.FP32), + IoVariable.FragmentCoord => ("position", AggregateType.Vector4 | AggregateType.FP32), _ => (null, AggregateType.Invalid), }; + + if (returnValue.Item2 == AggregateType.Invalid) + { + Logger.Warning?.PrintMsg(LogClass.Gpu, $"Unable to find type for IoVariable {ioVariable}!"); + } + + return returnValue; } private static (string, AggregateType) GetUserDefinedVariableName(ShaderDefinitions definitions, int location, int component, bool isOutput, bool isPerPatch)