Brainfuck compiler
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

bf-jit-unsafe-opt.c 16KB

3 years ago

  1. // This is an exercise in futility more than anything else
  2. #define _GNU_SOURCE
  3. #include <stdio.h>
  4. #include <stdlib.h>
  5. #include <string.h>
  6. #include <stdint.h>
  7. #include <stdbool.h>
  8. #include <assert.h>
  9. #include <errno.h>
  10. #if (defined __x86_64__ || defined __amd64__) && defined __unix__
  11. #include <unistd.h>
  12. #include <sys/mman.h>
  13. #else
  14. #error Platform not supported
  15. #endif
  /// Print a printf-style message prefixed with "fatal: " to standard error
  /// and terminate the process with EXIT_FAILURE.  The format has to be
  /// a string literal because it is concatenated with the prefix.
  #define exit_fatal(...)                                                      \
      do                                                                       \
      {                                                                        \
          fprintf (stderr, "fatal: " __VA_ARGS__);                             \
          exit (EXIT_FAILURE);                                                 \
      }                                                                        \
      while (0)
  21. // --- Safe memory management --------------------------------------------------
  /// Allocate a zero-initialized array of @a m elements, @a n bytes each,
  /// terminating the program when the allocation fails.
  ///
  /// A null result only counts as a failure when a non-empty block was asked
  /// for: calloc() may legitimately return NULL for zero-sized requests.
  /// This matches the equivalent check in xrealloc().
  static void *
  xcalloc (size_t m, size_t n)
  {
      void *p = calloc (m, n);
      if (!p && m && n)
          exit_fatal ("calloc: %s\n", strerror (errno));
      return p;
  }
  /// Resize the block @a o to @a n bytes, terminating the program when that
  /// fails.  A null result for a zero-sized request is passed through as is.
  static void *
  xrealloc (void *o, size_t n)
  {
      void *resized = realloc (o, n);
      if (n && !resized)
          exit_fatal ("realloc: %s\n", strerror (errno));
      return resized;
  }
  38. // --- Dynamically allocated strings -------------------------------------------
  /// Growable byte buffer whose contents are kept null-terminated
  struct str
  {
      char *str;     ///< Buffer contents, followed by a terminating null byte
      size_t alloc;  ///< Capacity of the buffer in bytes
      size_t len;    ///< Number of bytes in use, the terminator not included
  };
  45. static void
  46. str_init (struct str *self)
  47. {
  48. self->len = 0;
  49. self->str = xcalloc (1, (self->alloc = 16));
  50. }
  51. static void
  52. str_ensure_space (struct str *self, size_t n)
  53. {
  54. // We allocate at least one more byte for the terminating null character
  55. size_t new_alloc = self->alloc;
  56. while (new_alloc <= self->len + n)
  57. new_alloc <<= 1;
  58. if (new_alloc != self->alloc)
  59. self->str = xrealloc (self->str, (self->alloc = new_alloc));
  60. }
  61. static void
  62. str_append_data (struct str *self, const void *data, size_t n)
  63. {
  64. str_ensure_space (self, n);
  65. memcpy (self->str + self->len, data, n);
  66. self->str[self->len += n] = '\0';
  67. }
  /// Append the single byte @a c to @a self.
  static void
  str_append_c (struct str *self, char c)
  {
      const char byte[1] = { c };
      str_append_data (self, byte, sizeof byte);
  }
  73. // --- Application -------------------------------------------------------------
  /// Operations of the intermediate representation.  RIGHT through DEC, IN,
  /// OUT, BEGIN and END are decoded straight from program text; SET, EAT,
  /// INCACC and DECACC only ever come out of the optimization passes.
  enum command
  {
      RIGHT, LEFT, INC, DEC, SET, IN, OUT, BEGIN, END,
      EAT, INCACC, DECACC
  };

  /// Whether a run of the same command is merged into a single instruction
  /// while decoding, with the repetition count ending up in its argument
  bool grouped[] =
  {
      [RIGHT]  = true,  [LEFT]   = true,  [INC]    = true,
      [DEC]    = true,  [SET]    = true,
      [IN]     = false, [OUT]    = false,
      [BEGIN]  = false, [END]    = false,
      [EAT]    = false, [INCACC] = false, [DECACC] = false,
  };

  /// One instruction of the intermediate representation
  struct instruction
  {
      enum command cmd;  ///< What to do
      int offset;        ///< Tape offset the operation applies at
      size_t arg;        ///< Command-specific argument (count, value, target)
  };

  /// Build a struct instruction value from a command, an offset and an argument
  #define INSTRUCTION(c, o, a) \
      (struct instruction) { .cmd = (c), .offset = (o), .arg = (a) }
  79. // - - Callbacks - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  FILE *input;                            ///< User input

  /// Read one byte of user input from the "input" stream; this is what
  /// the generated machine code calls to implement the IN command.
  ///
  /// Running out of input is a runtime condition rather than a programming
  /// error, so it is reported explicitly instead of through assert():
  /// the check must not disappear in NDEBUG builds, and exit() flushes
  /// the output written so far whereas abort() need not.
  ///
  /// Returns the byte read, as an unsigned char converted to int.
  static int
  cin (void)
  {
      int c = fgetc (input);
      if (c == EOF)
          exit_fatal ("can't read input: %s\n",
              ferror (input) ? strerror (errno) : "end of file");
      return c;
  }
  88. // - - Main - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  #ifdef DEBUG
  /// Write a textual listing of @a len instructions starting at @a in to
  /// the file @a filename, one instruction per line, in the form of
  /// "NAME arg [offset]".  Instructions between BEGIN and END are indented
  /// according to their nesting.  Failure to create the file is fatal.
  static void
  debug_dump (const char *filename, struct instruction *in, size_t len)
  {
      FILE *fp = fopen (filename, "w");
      if (!fp)
          exit_fatal ("%s: %s\n", filename, strerror (errno));

      long indent = 0;
      for (size_t i = 0; i < len; i++)
      {
          // END lines up with its BEGIN, so it unindents before being printed
          if (in[i].cmd == END)
              indent--;
          for (long k = 0; k < indent; k++)
              fprintf (fp, " ");

          switch (in[i].cmd)
          {
          case RIGHT:  fputs ("RIGHT ", fp); break;
          case LEFT:   fputs ("LEFT ", fp); break;
          case INC:    fputs ("INC ", fp); break;
          case DEC:    fputs ("DEC ", fp); break;
          case OUT:    fputs ("OUT ", fp); break;
          case IN:     fputs ("IN ", fp); break;
          case BEGIN:  fputs ("BEGIN ", fp); break;
          case END:    fputs ("END ", fp); break;
          case SET:    fputs ("SET ", fp); break;
          case EAT:    fputs ("EAT ", fp); break;
          case INCACC: fputs ("INCACC", fp); break;
          case DECACC: fputs ("DECACC", fp); break;
          }
          fprintf (fp, " %zu [%d]\n", in[i].arg, in[i].offset);

          if (in[i].cmd == BEGIN)
              indent++;
      }
      fclose (fp);
  }
  #else
  // Dumps compile to nothing outside of debug builds
  #define debug_dump(...)
  #endif
  125. int
  126. main (int argc, char *argv[])
  127. {
  128. (void) argc;
  129. (void) argv;
  130. struct str program;
  131. str_init (&program);
  132. int c;
  133. while ((c = fgetc (stdin)) != EOF)
  134. str_append_c (&program, c);
  135. if (ferror (stdin))
  136. exit_fatal ("can't read program\n");
  137. if (!(input = fopen ("/dev/tty", "rb")))
  138. exit_fatal ("can't open terminal for reading\n");
  139. // - - Decode and group - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  140. struct instruction *parsed = xcalloc (sizeof *parsed, program.len);
  141. size_t parsed_len = 0;
  142. for (size_t i = 0; i < program.len; i++)
  143. {
  144. enum command cmd;
  145. switch (program.str[i])
  146. {
  147. case '>': cmd = RIGHT; break;
  148. case '<': cmd = LEFT; break;
  149. case '+': cmd = INC; break;
  150. case '-': cmd = DEC; break;
  151. case '.': cmd = OUT; break;
  152. case ',': cmd = IN; break;
  153. case '[': cmd = BEGIN; break;
  154. case ']': cmd = END; break;
  155. default: continue;
  156. }
  157. // The most basic optimization is to group identical commands together
  158. if (!parsed_len || !grouped[cmd] || parsed[parsed_len - 1].cmd != cmd)
  159. parsed_len++;
  160. parsed[parsed_len - 1].cmd = cmd;
  161. parsed[parsed_len - 1].arg++;
  162. }
  163. // - - Optimization passes - - - - - - - - - - - - - - - - - - - - - - - - - - -
  164. debug_dump ("bf-no-opt.txt", parsed, parsed_len);
  165. size_t in = 0, out = 0;
  166. for (; in < parsed_len; in++, out++)
  167. {
  168. if (in + 2 < parsed_len
  169. && parsed[in ].cmd == BEGIN
  170. && parsed[in + 1].cmd == DEC && parsed[in + 1].arg == 1
  171. && parsed[in + 2].cmd == END)
  172. {
  173. parsed[out] = INSTRUCTION (SET, 0, 0);
  174. in += 2;
  175. }
  176. else if (out && parsed[out - 1].cmd == SET && parsed[in].cmd == INC)
  177. parsed[--out].arg += parsed[in].arg;
  178. else if (out != in)
  179. parsed[out] = parsed[in];
  180. }
  181. parsed_len = out;
  182. debug_dump ("bf-pre-offsets.txt", parsed, parsed_len);
  183. // Add offsets to INC/DEC/SET stuck between LEFT/RIGHT
  184. // and compress the LEFT/RIGHT sequences
  185. for (in = 0, out = 0; in < parsed_len; in++, out++)
  186. {
  187. ssize_t dir = 0;
  188. if (parsed[in].cmd == RIGHT)
  189. dir = parsed[in].arg;
  190. else if (parsed[in].cmd == LEFT)
  191. dir = -(ssize_t) parsed[in].arg;
  192. else
  193. {
  194. parsed[out] = parsed[in];
  195. continue;
  196. }
  197. while (in + 2 < parsed_len)
  198. {
  199. // An immediate offset has its limits
  200. if (dir < INT8_MIN || dir > INT8_MAX)
  201. break;
  202. ssize_t diff;
  203. if (parsed[in + 2].cmd == RIGHT)
  204. diff = parsed[in + 2].arg;
  205. else if (parsed[in + 2].cmd == LEFT)
  206. diff = -(ssize_t) parsed[in + 2].arg;
  207. else
  208. break;
  209. int cmd = parsed[in + 1].cmd;
  210. if (cmd != INC && cmd != DEC && cmd != SET)
  211. break;
  212. parsed[out] = parsed[in + 1];
  213. parsed[out].offset = dir;
  214. dir += diff;
  215. out += 1;
  216. in += 2;
  217. }
  218. for (; in + 1 < parsed_len; in++)
  219. {
  220. if (parsed[in + 1].cmd == RIGHT)
  221. dir += parsed[in + 1].arg;
  222. else if (parsed[in + 1].cmd == LEFT)
  223. dir -= (ssize_t) parsed[in + 1].arg;
  224. else
  225. break;
  226. }
  227. if (!dir)
  228. out--;
  229. else if (dir > 0)
  230. parsed[out] = INSTRUCTION (RIGHT, 0, dir);
  231. else
  232. parsed[out] = INSTRUCTION (LEFT, 0, -dir);
  233. }
  234. parsed_len = out;
  235. debug_dump ("bf-pre-incdec-unloop.txt", parsed, parsed_len);
  236. // Try to eliminate loops that eat a cell and add/subtract its value
  237. // to/from some other cell
  238. for (in = 0, out = 0; in < parsed_len; in++, out++)
  239. {
  240. parsed[out] = parsed[in];
  241. if (parsed[in].cmd != BEGIN)
  242. continue;
  243. bool ok = false;
  244. size_t count = 0;
  245. for (size_t k = in + 1; k < parsed_len; k++)
  246. {
  247. if (parsed[k].cmd == END)
  248. {
  249. ok = true;
  250. break;
  251. }
  252. if (parsed[k].cmd != INC
  253. && parsed[k].cmd != DEC)
  254. break;
  255. count++;
  256. }
  257. if (!ok)
  258. continue;
  259. // Stable sort operations by their offsets, put [0] first
  260. bool sorted;
  261. do
  262. {
  263. sorted = true;
  264. for (size_t k = 1; k < count; k++)
  265. {
  266. if (parsed[in + k].offset == 0)
  267. continue;
  268. if (parsed[in + k + 1].offset != 0
  269. && parsed[in + k].offset <= parsed[in + k + 1].offset)
  270. continue;
  271. struct instruction tmp = parsed[in + k + 1];
  272. parsed[in + k + 1] = parsed[in + k];
  273. parsed[in + k] = tmp;
  274. sorted = false;
  275. }
  276. }
  277. while (!sorted);
  278. // Abort the optimization on duplicate offsets (complication with [0])
  279. for (size_t k = 1; k < count; k++)
  280. if (parsed[in + k].offset == parsed[in + k + 1].offset)
  281. ok = false;
  282. // XXX: can't make the code longer either
  283. for (size_t k = 1; k <= count; k++)
  284. if (parsed[in + k].arg != 1)
  285. ok = false;
  286. if (!ok
  287. || parsed[in + 1].cmd != DEC
  288. || parsed[in + 1].offset != 0)
  289. continue;
  290. int min_safe_left_offset = 0;
  291. if (in > 1 && parsed[in - 1].cmd == RIGHT)
  292. min_safe_left_offset = -parsed[in - 1].arg;
  293. bool cond_needed_for_safety = false;
  294. for (size_t k = 0; k < count; k++)
  295. if (parsed[in + k + 1].offset < min_safe_left_offset)
  296. {
  297. cond_needed_for_safety = true;
  298. break;
  299. }
  300. in++;
  301. if (cond_needed_for_safety)
  302. out++;
  303. parsed[out] = INSTRUCTION (EAT, 0, 0);
  304. for (size_t k = 1; k < count; k++)
  305. parsed[out + k] = INSTRUCTION (parsed[in + k].cmd == INC
  306. ? INCACC : DECACC, parsed[in + k].offset, 0);
  307. in += count;
  308. out += count;
  309. if (cond_needed_for_safety)
  310. parsed[out] = INSTRUCTION (END, 0, 0);
  311. else
  312. out--;
  313. }
  314. parsed_len = out;
  315. debug_dump ("bf-optimized.txt", parsed, parsed_len);
  316. // - - Loop pairing - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  317. size_t nesting = 0;
  318. size_t *stack = xcalloc (sizeof *stack, parsed_len);
  319. for (size_t i = 0; i < parsed_len; i++)
  320. {
  321. switch (parsed[i].cmd)
  322. {
  323. case BEGIN:
  324. stack[nesting++] = i;
  325. break;
  326. case END:
  327. assert (nesting > 0);
  328. --nesting;
  329. parsed[stack[nesting]].arg = i + 1;
  330. // Looping can be disabled by optimizations
  331. if (parsed[i].arg)
  332. parsed[i].arg = stack[nesting] + 1;
  333. default:
  334. break;
  335. }
  336. }
  337. free (stack);
  338. assert (nesting == 0);
  339. debug_dump ("bf-final.txt", parsed, parsed_len);
  340. // - - JIT - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  341. // Functions preserve the registers rbx, rsp, rbp, r12, r13, r14, and r15;
  342. // while rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11 are scratch registers.
  343. str_init (&program);
  344. size_t *offsets = xcalloc (sizeof *offsets, parsed_len + 1);
  345. uint8_t *arith = xcalloc (sizeof *arith, parsed_len);
  346. #define CODE(x) { char t[] = x; str_append_data (&program, t, sizeof t - 1); }
  347. #define WORD(x) { size_t t = (size_t)(x); str_append_data (&program, &t, 8); }
  348. #define DWRD(x) { size_t t = (size_t)(x); str_append_data (&program, &t, 4); }
  349. CODE ("\x48\x89\xF8") // mov rax, rdi
  350. CODE ("\x30\xDB") // xor bl, bl
  351. for (size_t i = 0; i < parsed_len; i++)
  352. {
  353. offsets[i] = program.len;
  354. size_t arg = parsed[i].arg;
  355. assert (arg <= UINT32_MAX);
  356. int offset = parsed[i].offset;
  357. assert (offset <= INT8_MAX && offset >= INT8_MIN);
  358. // Don't save what we've just loaded
  359. if (parsed[i].cmd == LEFT || parsed[i].cmd == RIGHT)
  360. if (i < 2 || i + 1 >= parsed_len
  361. || (parsed[i - 2].cmd != LEFT && parsed[i - 2].cmd != RIGHT)
  362. || parsed[i - 1].cmd != BEGIN
  363. || parsed[i + 1].cmd != END)
  364. CODE ("\x88\x18") // mov [rax], bl
  365. switch (parsed[i].cmd)
  366. {
  367. case RIGHT:
  368. // add rax, "arg" -- optimistic, no boundary checking
  369. if (arg > INT8_MAX)
  370. { CODE ("\x48\x05") DWRD (arg) }
  371. else
  372. { CODE ("\x48\x83\xC0") str_append_c (&program, arg); }
  373. break;
  374. case LEFT:
  375. // sub rax, "arg" -- optimistic, no boundary checking
  376. if (arg > INT8_MAX)
  377. { CODE ("\x48\x2D") DWRD (arg) }
  378. else
  379. { CODE ("\x48\x83\xE8") str_append_c (&program, arg); }
  380. break;
  381. case EAT:
  382. CODE ("\x41\x88\xDC") // mov r12b, bl
  383. CODE ("\x30\xDB") // xor bl, bl
  384. arith[i] = 1;
  385. break;
  386. case INCACC:
  387. if (offset)
  388. {
  389. CODE ("\x44\x00\x60") // add [rax+"offset"], r12b
  390. str_append_c (&program, offset);
  391. }
  392. else
  393. {
  394. CODE ("\x44\x00\xE3") // add bl, r12b
  395. arith[i] = 1;
  396. }
  397. break;
  398. case DECACC:
  399. if (offset)
  400. {
  401. CODE ("\x44\x28\x60") // sub [rax+"offset"], r12b
  402. str_append_c (&program, offset);
  403. }
  404. else
  405. {
  406. CODE ("\x44\x28\xE3") // sub bl, r12b
  407. arith[i] = 1;
  408. }
  409. break;
  410. case INC:
  411. if (offset)
  412. {
  413. CODE ("\x80\x40") // add byte [rax+"offset"], "arg"
  414. str_append_c (&program, offset);
  415. }
  416. else
  417. {
  418. arith[i] = 1;
  419. CODE ("\x80\xC3") // add bl, "arg"
  420. }
  421. str_append_c (&program, arg);
  422. break;
  423. case DEC:
  424. if (offset)
  425. {
  426. CODE ("\x80\x68") // sub byte [rax+"offset"], "arg"
  427. str_append_c (&program, offset);
  428. }
  429. else
  430. {
  431. arith[i] = 1;
  432. CODE ("\x80\xEB") // sub bl, "arg"
  433. }
  434. str_append_c (&program, arg);
  435. break;
  436. case SET:
  437. if (offset)
  438. {
  439. CODE ("\xC6\x40") // mov byte [rax+"offset"], "arg"
  440. str_append_c (&program, offset);
  441. }
  442. else
  443. CODE ("\xB3") // mov bl, "arg"
  444. str_append_c (&program, arg);
  445. break;
  446. case OUT:
  447. CODE ("\x50\x53") // push rax, push rbx
  448. CODE ("\x48\x0F\xB6\xFB") // movzx rdi, bl
  449. CODE ("\x48\xBE") WORD (stdout) // mov rsi, "stdout"
  450. CODE ("\x48\xB8") WORD (fputc) // mov rax, "fputc"
  451. CODE ("\xFF\xD0") // call rax
  452. CODE ("\x5B\x58") // pop rbx, pop rax
  453. break;
  454. case IN:
  455. CODE ("\x50") // push rax
  456. CODE ("\x48\xB8") WORD (cin) // mov rax, "cin"
  457. CODE ("\xFF\xD0") // call rax
  458. CODE ("\x88\xC3") // mov bl, al
  459. CODE ("\x58") // pop rax
  460. break;
  461. case BEGIN:
  462. // Don't test the register when the flag has been set already;
  463. // this doesn't have much of an effect in practice
  464. if (!i || !arith[i - 1])
  465. CODE ("\x84\xDB") // test bl, bl
  466. CODE ("\x0F\x84\x00\x00\x00\x00") // jz "offsets[i]"
  467. break;
  468. case END:
  469. // We know that the cell is zero, make this an "if", not a "loop";
  470. // this doesn't have much of an effect in practice
  471. if (!arg)
  472. break;
  473. if (!i || !arith[i - 1])
  474. CODE ("\x84\xDB") // test bl, bl
  475. CODE ("\x0F\x85\x00\x00\x00\x00") // jnz "offsets[i]"
  476. break;
  477. }
  478. // No sense in reading it out when we overwrite it immediately;
  479. // this doesn't have much of an effect in practice
  480. if (parsed[i].cmd == LEFT || parsed[i].cmd == RIGHT)
  481. if (i + 1 >= parsed_len
  482. || parsed[i + 1].cmd != SET
  483. || parsed[i + 1].offset != 0)
  484. CODE ("\x8A\x18") // mov bl, [rax]
  485. }
  486. // When there is a loop at the end we need to be able to jump past it
  487. offsets[parsed_len] = program.len;
  488. str_append_c (&program, '\xC3'); // ret
  489. // Now that we know where each instruction is, fill in relative jumps;
  490. // this must accurately reflect code generators for BEGIN and END
  491. for (size_t i = 0; i < parsed_len; i++)
  492. {
  493. if ((parsed[i].cmd != BEGIN && parsed[i].cmd != END)
  494. || !parsed[i].arg)
  495. continue;
  496. size_t fixup = offsets[i] + 2;
  497. if (!i || !arith[i - 1])
  498. fixup += 2;
  499. *(int32_t *)(program.str + fixup) =
  500. ((intptr_t)(offsets[parsed[i].arg]) - (intptr_t)(fixup + 4));
  501. }
  502. free (offsets);
  503. free (arith);
  504. #ifdef DEBUG
  505. FILE *bin = fopen ("bf-jit.bin", "w");
  506. fwrite (program.str, program.len, 1, bin);
  507. fclose (bin);
  508. #endif
  509. // - - Runtime - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  510. // Some systems may have W^X
  511. void *executable = mmap (NULL, program.len, PROT_READ | PROT_WRITE,
  512. MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
  513. if (!executable)
  514. exit_fatal ("mmap: %s\n", strerror (errno));
  515. memcpy (executable, program.str, program.len);
  516. if (mprotect (executable, program.len, PROT_READ | PROT_EXEC))
  517. exit_fatal ("mprotect: %s\n", strerror (errno));
  518. // We create crash zones on both ends of the tape for some minimum safety
  519. long pagesz = sysconf (_SC_PAGESIZE);
  520. assert (pagesz > 0);
  521. const size_t tape_len = (1 << 20) + 2 * pagesz;
  522. char *tape = mmap (NULL, tape_len, PROT_READ | PROT_WRITE,
  523. MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
  524. if (!tape)
  525. exit_fatal ("mmap: %s\n", strerror (errno));
  526. memset (tape, 0, tape_len);
  527. if (mprotect (tape, pagesz, PROT_NONE)
  528. || mprotect (tape + tape_len - pagesz, pagesz, PROT_NONE))
  529. exit_fatal ("mprotect: %s\n", strerror (errno));
  530. ((void (*) (char *)) executable)(tape + pagesz);
  531. return 0;
  532. }