chibicc

LLVM codegen

Open ghost opened this issue 3 years ago • 0 comments

Hello! :wave: Thank you for making this! I feel like it’s an incredibly helpful learning resource and a really neat project overall.

In case anyone finds it useful, I wanted to mention that I decided to start writing an LLVM backend for chibicc! It’s still nowhere near finished, but it can already compile some simple examples.

codegen.c
#include <inttypes.h>
#include <string.h>
#include "chibicc.h"

// Rounds n up to the nearest multiple of align (assumes align > 0).
int align_to(int n, int align) {
  int padded = n + align - 1;
  // padded - padded % align == (padded / align) * align for any C division.
  return padded - padded % align;
}

// Destination stream for all emitted LLVM IR; set by codegen().
static FILE *output_file;

// printf-style write of one line of IR to output_file; a newline is appended.
__attribute__((format(printf, 1, 2)))
static void println(char *fmt, ...) {
  va_list args;

  va_start(args, fmt);
  vfprintf(output_file, fmt, args);
  va_end(args);

  fputc('\n', output_file);
}

// Forward declarations: the code generators below are mutually recursive.
static char *gen_addr(Node *node);
static char *gen_expr(Node *node);
static char *gen_stmt(Node *node);

// Number of the most recently emitted unnamed value (%N) in the current
// function. to_val() increments it before each use, the ND_RETURN case of
// gen_stmt bumps it after a `ret` (the unnamed basic block that follows a
// terminator also consumes a number), and codegen() resets it per function.
// LLVM requires unnamed values to be numbered consecutively, so numbered
// values must only be created through to_val().
static int counter = 0;
// Source of unique suffixes for named basic-block labels (L.then.N, L.end.N,
// ...); count() increments it and codegen() resets it per function.
static int counter2 = 0;

// Binds the instruction text `str` to the next unnamed value, emits it as
// "%N = <str>", and returns the value's name "%N" for use as an operand.
static char *to_val(char *str) {
  int id = ++counter;
  println("  %%%d = %s", id, str);
  return format("%%%d", id);
}

// Returns a fresh per-function number used to make basic-block labels unique.
// Declared with (void) so it is a real prototype rather than an old-style
// declaration with unspecified parameters.
static int count(void) {
  return ++counter2;
}

// Maps a chibicc type to its LLVM IR spelling (typed-pointer syntax, e.g.
// "i32*"). LLVM integers are signless, so signed and unsigned variants share
// a spelling; signedness is chosen per instruction (sdiv/udiv, icmp slt/ult).
// Integer widths must agree with ty->size, because the ND_CAST lowering in
// gen_expr chooses trunc/ext by comparing sizes.
static char *gen_type(Type *ty) {
  switch (ty->kind) {
  case TY_VOID: return "void";
  // NOTE(review): i8 assumes chibicc's _Bool has size 1 (like char), so that
  // casts to/from it compare sizes consistently.
  case TY_BOOL: return "i8";
  case TY_CHAR: return "i8";
  case TY_SHORT: return "i16";
  case TY_INT: return "i32";
  case TY_ENUM: return "i32";
  case TY_LONG: return "i64";
  // LLVM's names for the IEEE single/double and x87 80-bit formats.
  case TY_FLOAT: return "float";
  case TY_DOUBLE: return "double";
  case TY_LDOUBLE: return "x86_fp80";
  case TY_ARRAY:
    return format("[%d x %s]", ty->array_len, gen_type(ty->base));
  case TY_PTR:
    // LLVM has no void*; i8* is the conventional generic pointer.
    if (ty->base->kind == TY_VOID) return "i8*";
    return format("%s*", gen_type(ty->base));
  // TODO
  // case TY_STRUCT:
  // case TY_UNION:
  // case TY_FUNC:
  // ...
  }
  unreachable();
}

// Evaluates `node` and returns an i1 value that is true when the result is
// nonzero (C truthiness), suitable as a `br i1` condition.
char *to_bool(Node *node) {
  Type *ty = node->ty;
  char *v = gen_expr(node);

  if (is_flonum(ty)) {
    // `une` (unordered or not-equal): NaN != 0 is true in C. `ne` is not a
    // valid fcmp predicate. x86_fp80 constants must use the 0xK hex form.
    char *zero = ty->kind == TY_LDOUBLE ? "0xK00000000000000000000" : "0.0";
    return to_val(format("fcmp une %s %s, %s", gen_type(ty), v, zero));
  }

  // Pointers must be compared against `null`; an integer 0 is not a pointer.
  char *zero = ty->kind == TY_PTR ? "null" : "0";
  return to_val(format("icmp ne %s %s, %s", gen_type(ty), v, zero));
}

// Lowers the argument list starting at `arg` and returns it formatted as
// "type value, type value, ..." for a call instruction, or NULL when the list
// is empty. The tail is lowered before the head, so argument instructions are
// emitted right-to-left.
char *fn_args(Node *arg) {
  if (arg == NULL)
    return NULL;

  char *rest = fn_args(arg->next);
  char *val = gen_expr(arg);
  char *head = format("%s %s", gen_type(arg->ty), val);

  if (rest == NULL)
    return head;
  return format("%s, %s", head, rest);
}

// Emits code computing the address of the lvalue `node` and returns it as an
// LLVM operand. A local is its alloca "%<name>.<offset>"; a global is
// "@<name>", except that a global array is returned already decayed to a
// pointer to its first element (via getelementptr).
static char *gen_addr(Node *node) {
  switch (node->kind) {
  case ND_VAR: {
    Obj *var = node->var;
    if (var->is_local)
      return format("%%%s.%d", var->name, var->offset);

    char *addr = format("@%s", var->name);
    if (node->ty->kind == TY_ARRAY) {
      // gen_type spells the whole array type "[N x T]" from array_len, which
      // avoids dividing by a possibly zero element size.
      char *arr = gen_type(node->ty);
      addr = to_val(format("getelementptr %s, %s* %s, i64 0, i64 0", arr, arr, addr));
    }
    return addr;
  }
  case ND_DEREF:
    // The address of *p is simply the value of p.
    return gen_expr(node->lhs);
  case ND_COMMA:
    // The left operand is evaluated only for its side effects. gen_expr has
    // already emitted its instructions; the returned operand name must not be
    // printed as a line of IR.
    gen_expr(node->lhs);
    return gen_addr(node->rhs);
  case ND_ASSIGN:
  case ND_COND:
    if (node->ty->kind == TY_STRUCT || node->ty->kind == TY_UNION) {
      // TODO
      error_tok(node->tok, "not implemented");
    }
    break;
  case ND_FUNCALL:
  case ND_MEMBER:
  case ND_VLA_PTR:
    // TODO
    error_tok(node->tok, "not implemented");
  }

  error_tok(node->tok, "not an lvalue");
}

static char *gen_expr(Node *node) {
  switch (node->kind) {
  case ND_NULL_EXPR:
    return NULL;
  case ND_NUM:
    switch (node->ty->kind) {
    case TY_FLOAT:
    case TY_DOUBLE:
    case TY_LDOUBLE:
      return format("%Lf", node->fval);
    default:
      return format("%ld", node->val);
    }
  case ND_NEG:
    if (is_flonum(node->ty))
      return to_val(format("fneg %s %s", gen_type(node->ty), gen_expr(node->lhs)));
    else
      return to_val(format("sub %s %s, 0", gen_type(node->ty), gen_expr(node->lhs)));
  case ND_VAR:
    if (node->ty->kind == TY_ARRAY)
      return gen_addr(node);
    return to_val(format("load %s, %s* %s", gen_type(node->ty), gen_type(node->ty), gen_addr(node)));
  case ND_MEMBER:
    return to_val(format("load %s, %s* %s", gen_type(node->member->ty), gen_type(node->member->ty), gen_addr(node)));
  case ND_DEREF:
    return to_val(format("load %s, %s %s", gen_type(node->ty), gen_type(node->lhs->ty), gen_expr(node->lhs)));
  case ND_ADDR:
    return gen_addr(node->lhs);
  case ND_ASSIGN: {
    if (node->lhs->kind == ND_MEMBER && node->lhs->member->is_bitfield) {
      // TODO
      error_tok(node->tok, "not implemented");
    }
    char *v1 = gen_addr(node->lhs);
    char *v2 = gen_expr(node->rhs);
    println("  store %s %s, %s* %s", gen_type(node->rhs->ty), v2, gen_type(node->lhs->ty), v1);
    return v2;
  }
  case ND_STMT_EXPR: {
    char *res;
    for (Node *n = node->body; n; n = n->next)
      res = gen_stmt(n);
    if (res == NULL) error_tok(node->tok, "invalid last statement kind");
    return res;
  }
  case ND_COMMA:
    gen_expr(node->lhs);
    return gen_expr(node->rhs);
  case ND_CAST: {
    char *kw;
    if (is_integer(node->ty) && is_integer(node->lhs->ty)) {
      if (node->ty->size < node->lhs->ty->size)
        kw = "trunc";
      else if (node->ty->size > node->lhs->ty->size)
        kw = "zext";
      else
        return gen_expr(node->lhs);
    }
    else if (is_flonum(node->ty) && is_flonum(node->lhs->ty)) {
      if (node->ty->size < node->lhs->ty->size)
        kw = "fptrunc";
      else if (node->ty->size > node->lhs->ty->size)
        kw = "fpext";
      else
        return gen_expr(node->lhs);
    }
    else if (is_flonum(node->ty) && is_integer(node->lhs->ty)) {
      if (node->lhs->ty->is_unsigned)
        kw = "uitofp";
      else
        kw = "sitofp";
    }
    else if (is_integer(node->ty) && is_flonum(node->lhs->ty)) {
      if (node->lhs->ty->is_unsigned)
        kw = "fptoui";
      else
        kw = "fptosi";
    }
    else if (node->ty->kind == TY_PTR && is_integer(node->lhs->ty)) {
      kw = "inttoptr";
    }
    else if (is_integer(node->ty) && node->lhs->ty->kind == TY_PTR) {
      kw = "ptrtoint";
    }
    else if (node->ty->kind == TY_PTR && node->lhs->ty->kind == TY_ARRAY) {
      return gen_addr(node->lhs);
    }
    else {
      return gen_expr(node->lhs);
    }
    return to_val(format("%s %s %s to %s", kw, gen_type(node->lhs->ty), gen_expr(node->lhs), gen_type(node->ty)));
  }
  case ND_MEMZERO:
    return "zeroinitializer";
  case ND_COND: {
    int c = count();
    println("  br i1 %s, label %%L.then.%d, label %%L.else.%d", to_bool(node->cond), c, c);
    println("L.then.%d:", c);
    char *v1 = gen_expr(node->then);
    println("  br label %%L.end.%d", c);
    println("L.else.%d:", c);
    char *v2 = gen_expr(node->els);
    println("  br label %%L.end.%d", c);
    println("L.end.%d:", c);
    return to_val(format("phi %s [ %s %%L.then.%d, %s %%L.else.%d ]", gen_type(node->ty), v1, c, v2, c));
  }
  case ND_NOT: {
    char *v;
    if (is_flonum(node->lhs->ty)) {
      v = to_val(format("fcmp eq %s %s, 0.0", gen_type(node->lhs->ty), gen_expr(node->lhs)));
    } else {
      v = to_val(format("icmp eq %s %s, 0", gen_type(node->lhs->ty), gen_expr(node->lhs)));
    }
    return to_val(format("zext i1 %s to %s", v, gen_type(node->lhs->ty)));
  }
  case ND_BITNOT:
    return to_val(format("xor %s %s, -1", gen_type(node->lhs->ty), gen_expr(node->lhs)));
  case ND_LOGAND: {
    int c = count();
    char *v1 = to_bool(node->lhs);
    println("  br label %%L.and.%d", c);
    println("L.and.%d:", c);
    println("  br i1 %s, label %%L.true.%d, label %%L.end.%d", v1, c, c);
    println("L.true.%d:", c);
    char *v2 = to_bool(node->rhs);
    println("  br label %%L.end.%d", c);
    println("L.end.%d:", c);
    return to_val(format("phi %s [ 0 %%L.and.%d, %s %%L.true.%d ]", gen_type(node->ty), c, v2, c));
  }
  case ND_LOGOR: {
    int c = count();
    char *v1 = to_bool(node->lhs);
    println("  br label %%L.or.%d", c);
    println("L.or.%d:", c);
    println("  br i1 %s, label %%L.end.%d, label %%L.false.%d", v1, c, c);
    println("L.false.%d:", c);
    char *v2 = to_bool(node->rhs);
    println("  br label %%L.end.%d", c);
    println("L.end.%d:", c);
    return to_val(format("phi %s [ 1 %%L.or.%d, %s %%L.false.%d ]", gen_type(node->ty), c, v2, c));
  }
  case ND_FUNCALL: {
    if (node->lhs->kind == ND_VAR && !strcmp(node->lhs->var->name, "alloca")) {
      // TODO
      error_tok(node->tok, "not implemented");
    }
    char *args = fn_args(node->args);
    if (node->ty->kind == TY_VOID) {
      println("  call void %s(%s)", gen_addr(node->lhs), args);
      return NULL;
    }
    return to_val(format("call %s %s(%s)", gen_type(node->ty), gen_addr(node->lhs), args));
  }
  case ND_ADD:
  case ND_SUB:
  case ND_MUL:
  case ND_DIV:
  case ND_MOD:
  case ND_BITAND:
  case ND_BITOR:
  case ND_BITXOR:
  case ND_EQ:
  case ND_NE:
  case ND_LT:
  case ND_LE: {
    char *v1 = gen_expr(node->lhs);
    char *v2 = gen_expr(node->rhs);

    char *kw;
    switch (node->kind) {
    case ND_ADD: kw = is_flonum(node->lhs->ty) ? "fadd" : "add"; break;
    case ND_SUB: kw = is_flonum(node->lhs->ty) ? "fsub" : "sub"; break;
    case ND_MUL: kw = is_flonum(node->lhs->ty) ? "fmul" : "mul"; break;
    case ND_DIV: kw = is_flonum(node->lhs->ty) ? "fdiv" : node->lhs->ty->is_unsigned ? "udiv" : "sdiv"; break;
    case ND_MOD: kw = is_flonum(node->lhs->ty) ? "frem" : node->lhs->ty->is_unsigned ? "urem" : "srem"; break;
    case ND_BITAND: kw = "and"; break;
    case ND_BITOR: kw = "or"; break;
    case ND_BITXOR: kw = "xor"; break;
    case ND_EQ: kw = is_flonum(node->lhs->ty) ? "fcmp oeq" : "icmp eq"; break;
    case ND_NE: kw = is_flonum(node->lhs->ty) ? "fcmp une" : "icmp ne"; break;
    case ND_LT: kw = is_flonum(node->lhs->ty) ? "fcmp ult" : node->lhs->ty->is_unsigned ? "icmp ult" : "icmp slt"; break;
    case ND_LE: kw = is_flonum(node->lhs->ty) ? "fcmp ule" : node->lhs->ty->is_unsigned ? "icmp ule" : "icmp sle"; break;
    }

    char *ty = gen_type(node->lhs->ty);
    if (node->ty->kind == TY_PTR) {
      ty = "i64";
      if (node->lhs->ty-> kind == TY_PTR) {
        v1 = to_val(format("ptrtoint %s %s to i64", gen_type(node->lhs->ty), v1));
        if (node->rhs->ty-> kind != TY_PTR)
          v2 = to_val(format("mul i64 %s, 8", v2));
      }
      if (node->rhs->ty-> kind == TY_PTR) {
        v2 = to_val(format("ptrtoint %s %s to i64", gen_type(node->rhs->ty), v2));
        if (node->lhs->ty-> kind != TY_PTR)
          v1 = to_val(format("mul i64 %s, 8", v1));
      }
    }

    char *v = to_val(format("%s %s %s, %s", kw, ty, v1, v2));

    if (node->ty->kind == TY_PTR)
      v = to_val(format("inttoptr i64 %s to %s", v, gen_type(node->ty)));

    switch (node->kind) {
    case ND_EQ:
    case ND_NE:
    case ND_LT:
    case ND_LE:
      return to_val(format("zext i1 %s to %s", v, gen_type(node->ty)));
    }

    return v;
  }
  case ND_LABEL_VAL:
  case ND_CAS:
  case ND_EXCH:
    // TODO
    error_tok(node->tok, "not implemented");
  }

  error_tok(node->tok, "invalid expression");
}

// Ends the current basic block with a branch to `label`, then starts the
// block `label`. LLVM requires every block to end in a terminator, so a named
// label is never printed without a preceding branch.
static void open_block(char *label) {
  println("  br label %%%s", label);
  println("%s:", label);
}

// Emits the statement `node`. Returns the value of an expression statement
// (so ND_STMT_EXPR can use its last statement's value) and NULL otherwise.
// Every construct emits a branch or return before printing its own labels.
// The unnamed block LLVM starts after a `ret` or `goto` therefore always
// holds an instruction, and `counter` is bumped after those terminators
// because that block consumes a value number.
static char *gen_stmt(Node *node) {
  switch (node->kind) {
  case ND_IF: {
    int c = count();
    char *cond = to_bool(node->cond);
    println("  br i1 %s, label %%L.then.%d, label %%L.else.%d", cond, c, c);
    println("L.then.%d:", c);
    gen_stmt(node->then);
    println("  br label %%L.end.%d", c);
    println("L.else.%d:", c);
    // `if` without `else` has no els node.
    if (node->els)
      gen_stmt(node->els);
    println("  br label %%L.end.%d", c);
    println("L.end.%d:", c);
    return NULL;
  }
  case ND_FOR: {
    // Also used for `while` loops (no init/inc).
    int c = count();
    if (node->init)
      gen_stmt(node->init);
    println("  br label %%L.loop.%d", c);
    println("L.loop.%d:", c);
    if (node->cond) {
      char *cond = to_bool(node->cond);
      println("  br i1 %s, label %%L.then.%d, label %%L.break.%d", cond, c, c);
    } else {
      // `for (;;)`: the condition block still needs a terminator.
      println("  br label %%L.then.%d", c);
    }
    println("L.then.%d:", c);
    gen_stmt(node->then);
    // `continue` (a goto to cont_label) lands here, so the increment runs.
    open_block(node->cont_label);
    if (node->inc)
      gen_expr(node->inc);
    println("  br label %%L.loop.%d", c);
    println("L.break.%d:", c);
    // `break` is a goto to brk_label.
    open_block(node->brk_label);
    return NULL;
  }
  case ND_DO: {
    int c = count();
    open_block(format("L.body.%d", c));
    gen_stmt(node->then);
    open_block(node->cont_label);
    char *cond = to_bool(node->cond);
    println("  br i1 %s, label %%L.body.%d, label %%%s", cond, c, node->brk_label);
    println("%s:", node->brk_label);
    return NULL;
  }
  case ND_BLOCK:
    for (Node *n = node->body; n; n = n->next)
      gen_stmt(n);
    return NULL;
  case ND_LABEL:
    open_block(node->unique_label);
    gen_stmt(node->lhs);
    return NULL;
  case ND_GOTO:
    // Also used for `break` and `continue`.
    println("  br label %%%s", node->unique_label);
    // As after `ret`: whatever follows starts an unnamed block, which takes
    // the next value number.
    counter++;
    return NULL;
  case ND_RETURN:
    if (node->lhs)
      println("  ret %s %s", gen_type(node->lhs->ty), gen_expr(node->lhs));
    else
      println("  ret void");
    // The unnamed block that follows the `ret` consumes a value number.
    counter++;
    return NULL;
  case ND_EXPR_STMT:
    return gen_expr(node->lhs);
  case ND_ASM:
    error_tok(node->tok, "unavailable statement type");
    return NULL;
  case ND_SWITCH:
  case ND_CASE:
  case ND_GOTO_EXPR:
    // TODO
    error_tok(node->tok, "not implemented");
  }

  error_tok(node->tok, "invalid statement");
}

// Formats the parameter list of a `define` line as
// "type %arg.<name>.<offset>, ..."; returns "" when there are no parameters.
static char *fn_params(Obj *params) {
  char *list = "";
  char *sep = "";

  for (Obj *p = params; p != NULL; p = p->next) {
    list = format("%s%s%s %%arg.%s.%d", list, sep, gen_type(p->ty), p->name, p->offset);
    sep = ", ";
  }
  return list;
}

// Formats the parameter types of a function type (a list linked through
// `next`) as "T1, T2, ..." for a `declare` line; returns "" when there are none.
// NOTE(review): variadic functions do not get a trailing "...".
static char *fn_param_types(Type *params) {
  char *res = "";
  for (Type *ty = params; ty; ty = ty->next)
    res = format("%s, %s", res, gen_type(ty));
  // Drop the ", " that precedes the first entry.
  if (res[0])
    res += 2;
  return res;
}

// Renders element `i` of the initializer bytes `data`, viewed as an array of
// the scalar type `ty`, as an LLVM constant:
// - integers as signed decimals of their own width (always in range for iN);
// - float/double in LLVM's 16-digit hex double form;
// - long double in the 0xK form.
// Returns "zeroinitializer" when there is no initializer data.
static char *global_value(char *data, Type *ty, int i) {
  if (!data)
    return "zeroinitializer";

  // Elements are read with memcpy rather than through casted pointers: the
  // bytes need not be aligned for wide loads, and casts break strict aliasing.
  char *p = data + (size_t)i * ty->size;

  switch (ty->kind) {
  case TY_FLOAT: {
    // LLVM spells a float constant as the double with the same value.
    float f;
    memcpy(&f, p, sizeof(f));
    double d = f;
    uint64_t bits;
    memcpy(&bits, &d, sizeof(bits));
    return format("0x%016" PRIX64, bits);
  }
  case TY_DOUBLE: {
    uint64_t bits;
    memcpy(&bits, p, sizeof(bits));
    return format("0x%016" PRIX64, bits);
  }
  case TY_LDOUBLE: {
    // NOTE(review): assumes the bytes hold an x87 80-bit value in their first
    // 10 bytes, little-endian (the host long double on x86-64).
    unsigned char *b = (unsigned char *)p;
    return format("0xK%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X",
                  b[9], b[8], b[7], b[6], b[5], b[4], b[3], b[2], b[1], b[0]);
  }
  default:
    break;
  }

  int64_t v;
  switch (ty->size) {
  case 1: { int8_t x; memcpy(&x, p, sizeof(x)); v = x; break; }
  case 2: { int16_t x; memcpy(&x, p, sizeof(x)); v = x; break; }
  case 4: { int32_t x; memcpy(&x, p, sizeof(x)); v = x; break; }
  case 8: memcpy(&v, p, sizeof(v)); break;
  default: unreachable();
  }
  // NOTE(review): pointer-typed globals are also printed as plain integers,
  // which LLVM does not accept for pointer types; pointer initializers are
  // not supported yet.
  return format("%" PRId64, v);
}

// Renders the array initializer bytes `data` of array type `ty` as an LLVM
// constant "[T v0, T v1, ...]". Nested arrays are rendered recursively, each
// sub-array starting at its own byte offset. Returns "zeroinitializer" when
// there is no initializer data.
static char *global_array(char *data, Type *ty) {
  if (!data)
    return "zeroinitializer";

  char *res = "";
  for (int i = 0; i < ty->array_len; i++) {
    char *elem = ty->base->kind == TY_ARRAY
      ? global_array(data + (size_t)i * ty->base->size, ty->base)
      : global_value(data, ty->base, i);
    res = format("%s%s%s %s", res, i ? ", " : "", gen_type(ty->base), elem);
  }
  return format("[%s]", res);
}

// Emits one global variable.
// - Declarations without a definition (extern) become `external global`
//   references, since they have no initializer data.
// - Definitions without initializer data are zero-filled.
// - Static variables get internal linkage, so they cannot clash with
//   same-named symbols in other object files.
static void emit_global_var(Obj *var) {
  char *ty = gen_type(var->ty);
  if (!var->is_definition) {
    println("@%s = external global %s", var->name, ty);
    return;
  }

  char *init;
  if (!var->init_data)
    init = "zeroinitializer";
  else if (var->ty->kind == TY_ARRAY)
    init = global_array(var->init_data, var->ty);
  else
    init = global_value(var->init_data, var->ty, 0);

  println("@%s = %sglobal %s %s", var->name, var->is_static ? "internal " : "", ty, init);
}

// Emits a `declare` for a function without a body, or a full `define`:
// - one alloca per local;
// - a store of each incoming argument into its local;
// - the body;
// - a fallback terminator (main returns 0, void functions `ret void`,
//   anything else is marked unreachable).
static void emit_function(Obj *fn) {
  // Locals are named "%<name>.<offset>"; numbering them keeps repeated or
  // shadowed names distinct.
  int i = 1;
  for (Obj *var = fn->locals; var; var = var->next)
    var->offset = i++;

  if (!fn->is_definition) {
    char *params = fn_param_types(fn->ty->params);
    println("declare %s @%s(%s)", gen_type(fn->ty->return_ty), fn->name, params);
    return;
  }

  // Value and label numbering restart in every function.
  counter = 0;
  counter2 = 0;

  char *params = fn_params(fn->params);
  println("define %s%s @%s(%s) {", fn->is_static ? "internal " : "",
          gen_type(fn->ty->return_ty), fn->name, params);

  for (Obj *var = fn->locals; var; var = var->next)
    println("  %%%s.%d = alloca %s", var->name, var->offset, gen_type(var->ty));
  for (Obj *var = fn->params; var; var = var->next)
    println("  store %s %%arg.%s.%d, %s* %%%s.%d", gen_type(var->ty), var->name, var->offset,
            gen_type(var->ty), var->name, var->offset);

  gen_stmt(fn->body);

  // Terminates the last block if control can fall off the end of the body.
  if (!strcmp(fn->name, "main"))
    println("  ret i32 0");
  else if (fn->ty->return_ty->kind == TY_VOID)
    println("  ret void");
  else
    println("  unreachable");

  println("}");
}

// Writes the LLVM IR for every global variable and function in `prog` to `out`.
void codegen(Obj *prog, FILE *out) {
  output_file = out;

  for (Obj *obj = prog; obj; obj = obj->next) {
    if (obj->is_function)
      emit_function(obj);
    else
      emit_global_var(obj);
  }
}

This is the file I have been using to test it, and so far it seems to compile and run fine!

test.c
// NOTE(review): simplified prototype; POSIX declares
// ssize_t write(int fd, const void *buf, size_t count). Confirm that this
// int-based signature matches the target ABI.
int write(int fd, void *buffer, int len);

// Writes the NUL-terminated string `str` to file descriptor 1 (stdout).
void print(char *str)
{
   int l = 0, x;
   // Compute the string length by hand.
   while (str[l]) l++;
   // Keep writing until all `l` bytes are out, advancing past partial writes.
   // NOTE(review): a negative write() result is not handled: `l -= x` would
   // then grow `l`, so the loop would not stop.
   while (l) x = write(1, str, l), l -= x, str += x;
}

// Writes the single character `ch` to stdout; the write() result is ignored.
void put(char ch)
{
   write(1, &ch, 1);
}

// Returns the uppercase hexadecimal digit for 0 <= i < 16 ('0'-'9', 'A'-'F').
// The range is not checked; other values map to other characters.
char hex(int i)
{
   if (i < 10) return '0' + i;
   else return 'A' + i - 10;
}

// Prints a header, then one line for each i in [0, 0x40): i as two decimal
// digits followed by i as two hex digits, e.g. "dec 42 = hex 2A".
int main()
{
   print(" = hex table =\n");
   // The loop body is a single comma expression.
   for (int i = 0 ; i < 0x40 ; i++)
      print("dec "), put(hex(i / 10)), put(hex(i % 10)),
      print(" = hex "), put(hex(i / 0x10)), put(hex(i % 0x10)),
      put('\n');
   return 0;
}
./chibicc -S -o ../test.ll ../test.c
clang -O3 -o ../test ../test.ll
../test

I’m hoping to continue working on this, and I’m going to keep posting updates on this issue, if that’s fine! (I’m going to use edits to avoid pinging watchers, though.)

changelog

2021

  • August 21: Started working on LLVM codegen.
  • October 25: Improved support for global values.

ghost avatar Aug 21 '21 13:08 ghost