From a9dd6e30adc590e11e3a076c1245f1b0b48f27f6 Mon Sep 17 00:00:00 2001
From: RSDuck <rsduck@users.noreply.github.com>
Date: Sat, 25 Apr 2020 19:35:40 +0200
Subject: implement msr and mrs for the x64 JIT

---
 src/ARMJIT.cpp                     |   2 +-
 src/ARMJIT_x64/ARMJIT_Compiler.cpp | 127 ++++++++++++++++++++++++++++++++++++-
 src/ARMJIT_x64/ARMJIT_Compiler.h   |   3 +
 src/ARM_InstrInfo.cpp              |   4 ++
 4 files changed, 134 insertions(+), 2 deletions(-)

diff --git a/src/ARMJIT.cpp b/src/ARMJIT.cpp
index cc8d4ce..46f71f1 100644
--- a/src/ARMJIT.cpp
+++ b/src/ARMJIT.cpp
@@ -824,7 +824,7 @@ void InvalidateITCM(u32 addr)
 
 void InvalidateAll()
 {
-	JIT_DEBUGPRINT("invalidating all %x\n", JitBlocks.Length);
+	JIT_DEBUGPRINT("invalidating all %x\n", JitBlocks.size());
 	for (auto it : JitBlocks)
 	{
 		JitBlock* block = it.second;
diff --git a/src/ARMJIT_x64/ARMJIT_Compiler.cpp b/src/ARMJIT_x64/ARMJIT_Compiler.cpp
index 1b2d312..52a16dc 100644
--- a/src/ARMJIT_x64/ARMJIT_Compiler.cpp
+++ b/src/ARMJIT_x64/ARMJIT_Compiler.cpp
@@ -38,6 +38,131 @@ const int RegisterCache<Compiler, X64Reg>::NativeRegsAvailable =
 #endif
 ;
 
+// Emits host code for ARM MRS: the register mapped from CurInstr.A_Reg(12) gets
+// RCPSR, or, when bit 22 is set, the value ReadBanked leaves in ABI_PARAM3.
+void Compiler::A_Comp_MRS()
+{
+    Comp_AddCycles_C();
+    OpArg dest = MapReg(CurInstr.A_Reg(12));
+    if (!(CurInstr.Instr & (1 << 22)))
+        MOV(32, dest, R(RCPSR));
+    else
+    {
+        MOV(32, R(RSCRATCH), R(RCPSR));
+        AND(32, R(RSCRATCH), Imm8(0x1F)); // keep only the low five bits of RCPSR
+        XOR(32, R(ABI_PARAM3), R(ABI_PARAM3)); // zeroed going in, read back after the call
+        MOV(32, R(ABI_PARAM2), Imm32(15 - 8));
+        CALL(ReadBanked);
+        MOV(32, dest, R(ABI_PARAM3));
+    }
+}
+
+// Emits host code for ARM MSR: merges the operand into RCPSR or, when bit 22 is
+// set, into banked slot 15 - 8, under a byte mask built from instruction bits 16-19.
+void Compiler::A_Comp_MSR()
+{
+    Comp_AddCycles_C();
+    // Operand: rotated 8-bit immediate when bit 25 is set, otherwise a mapped register.
+    OpArg val = CurInstr.Instr & (1 << 25)
+        ? Imm32(ROR((CurInstr.Instr & 0xFF), ((CurInstr.Instr >> 7) & 0x1E)))
+        : MapReg(CurInstr.A_Reg(0));
+
+    u32 mask = 0;
+    if (CurInstr.Instr & (1<<16)) mask |= 0x000000FF;
+    if (CurInstr.Instr & (1<<17)) mask |= 0x0000FF00;
+    if (CurInstr.Instr & (1<<18)) mask |= 0x00FF0000;
+    if (CurInstr.Instr & (1<<19)) mask |= 0xFF000000;
+
+    if (CurInstr.Instr & (1 << 22))
+    {
+        // NOTE(review): this path assumes RSCRATCH2/RSCRATCH3 never alias ABI_PARAM3 - confirm per ABI.
+        MOV(32, R(RSCRATCH), R(RCPSR));
+        AND(32, R(RSCRATCH), Imm8(0x1F));
+        XOR(32, R(ABI_PARAM3), R(ABI_PARAM3));
+        MOV(32, R(ABI_PARAM2), Imm32(15 - 8));
+        CALL(ReadBanked);
+        // RSCRATCH2 = mask, with the low byte dropped when the low five RCPSR bits equal 0x10.
+        MOV(32, R(RSCRATCH2), Imm32(0xFFFFFF00));
+        MOV(32, R(RSCRATCH3), Imm32(0xFFFFFFFF));
+        MOV(32, R(RSCRATCH), R(RCPSR));
+        AND(32, R(RSCRATCH), Imm8(0x1F));
+        CMP(32, R(RSCRATCH), Imm8(0x10));
+        CMOVcc(32, RSCRATCH2, R(RSCRATCH3), CC_NE);
+        AND(32, R(RSCRATCH2), Imm32(mask));
+        // ABI_PARAM3 = (ABI_PARAM3 & ~RSCRATCH2) | (val & RSCRATCH2)
+        MOV(32, R(RSCRATCH), R(RSCRATCH2));
+        NOT(32, R(RSCRATCH));
+        AND(32, R(ABI_PARAM3), R(RSCRATCH));
+        AND(32, R(RSCRATCH2), val);
+        OR(32, R(ABI_PARAM3), R(RSCRATCH2));
+        // Reload the same RSCRATCH/ABI_PARAM2 arguments before storing the merged value.
+        MOV(32, R(RSCRATCH), R(RCPSR));
+        AND(32, R(RSCRATCH), Imm8(0x1F));
+        MOV(32, R(ABI_PARAM2), Imm32(15 - 8));
+        CALL(WriteBanked);
+    }
+    else
+    {
+        mask &= 0xFFFFFFDF; // bit 5 of RCPSR (the ARM T bit) is never written here
+        CPSRDirty = true;
+
+        if ((mask & 0xFF) == 0)
+        {
+            AND(32, R(RCPSR), Imm32(~mask));
+            if (val.IsImm()) // Imm32() is only meaningful for an immediate operand
+            {
+                OR(32, R(RCPSR), Imm32(val.Imm32() & mask));
+            }
+            else // register operand: mask a copy in RSCRATCH, leaving the mapped register intact
+            {
+                MOV(32, R(RSCRATCH), val);
+                AND(32, R(RSCRATCH), Imm32(mask));
+                OR(32, R(RCPSR), R(RSCRATCH));
+            }
+        }
+        else
+        {
+            MOV(32, R(RSCRATCH2), Imm32(mask));
+            MOV(32, R(RSCRATCH3), R(RSCRATCH2));
+            AND(32, R(RSCRATCH3), Imm32(0xFFFFFF00));
+            MOV(32, R(RSCRATCH), R(RCPSR));
+            AND(32, R(RSCRATCH), Imm8(0x1F));
+            CMP(32, R(RSCRATCH), Imm8(0x10));
+            CMOVcc(32, RSCRATCH2, R(RSCRATCH3), CC_E); // low five bits == 0x10: drop the low byte from the mask
+            // RSCRATCH3 keeps the pre-write RCPSR; it becomes UpdateMode's second argument.
+            MOV(32, R(RSCRATCH3), R(RCPSR));
+            // I need you ANDN
+            MOV(32, R(RSCRATCH), R(RSCRATCH2));
+            NOT(32, R(RSCRATCH));
+            AND(32, R(RCPSR), R(RSCRATCH));
+            AND(32, R(RSCRATCH2), val);
+            OR(32, R(RCPSR), R(RSCRATCH2));
+
+            BitSet16 hiRegsLoaded(RegCache.LoadedRegs & 0x7F00);
+            if (Thumb || CurInstr.Cond() >= 0xE)
+                RegCache.Flush();
+            else
+            {
+                // the ugly way...
+                // we only save them, to load and save them again
+                for (int reg : hiRegsLoaded)
+                    SaveReg(reg, RegCache.Mapping[reg]);
+            }
+
+            MOV(32, R(ABI_PARAM3), R(RCPSR));
+            MOV(32, R(ABI_PARAM2), R(RSCRATCH3));
+            MOV(64, R(ABI_PARAM1), R(RCPU));
+            CALL((void*)&ARM::UpdateMode);
+
+            if (!Thumb && CurInstr.Cond() < 0xE)
+            {
+                for (int reg : hiRegsLoaded)
+                    LoadReg(reg, RegCache.Mapping[reg]);
+            }
+        }
+    }
+}
+
 /*
     We'll repurpose this .bss memory
 
@@ -328,7 +453,7 @@ const Compiler::CompileFunc A_Comp[ARMInstrInfo::ak_Count] =
     // Branch
     F(A_Comp_BranchImm), F(A_Comp_BranchImm), F(A_Comp_BranchImm), F(A_Comp_BranchXchangeReg), F(A_Comp_BranchXchangeReg),
     // system stuff
-    NULL, NULL, NULL, NULL, NULL, NULL, NULL,
+    NULL, F(A_Comp_MSR), F(A_Comp_MSR), F(A_Comp_MRS), NULL, NULL, NULL,
     F(Nop)
 };
 
diff --git a/src/ARMJIT_x64/ARMJIT_Compiler.h b/src/ARMJIT_x64/ARMJIT_Compiler.h
index a448b6d..2230eb8 100644
--- a/src/ARMJIT_x64/ARMJIT_Compiler.h
+++ b/src/ARMJIT_x64/ARMJIT_Compiler.h
@@ -100,6 +100,9 @@ public:
     void A_Comp_BranchImm();
     void A_Comp_BranchXchangeReg();
 
+    void A_Comp_MRS();
+    void A_Comp_MSR();
+
     void T_Comp_ShiftImm();
     void T_Comp_AddSub_();
     void T_Comp_ALU_Imm8();
diff --git a/src/ARM_InstrInfo.cpp b/src/ARM_InstrInfo.cpp
index b884773..28362d9 100644
--- a/src/ARM_InstrInfo.cpp
+++ b/src/ARM_InstrInfo.cpp
@@ -427,6 +427,10 @@ Info Decode(bool thumb, u32 num, u32 instr)
                 res.Kind = ak_UNK;
             }
         }
+        if (res.Kind == ak_MRS && !(instr & (1 << 22)))
+            res.ReadFlags |= flag_N | flag_Z | flag_C | flag_V;
+        if ((res.Kind == ak_MSR_IMM || res.Kind == ak_MSR_REG) && instr & (1 << 19))
+            res.WriteFlags |= flag_N | flag_Z | flag_C | flag_V;
 
         if (data & A_Read0)
             res.SrcRegs |= 1 << (instr & 0xF);
-- 
cgit v1.2.3