root/lj_carith.c

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. carith_checkarg
  2. carith_ptr
  3. carith_int64
  4. lj_carith_meta
  5. lj_carith_op
  6. lj_carith_mul64
  7. lj_carith_divu64
  8. lj_carith_divi64
  9. lj_carith_modu64
  10. lj_carith_modi64
  11. lj_carith_powu64
  12. lj_carith_powi64

   1 /*
   2 ** C data arithmetic.
   3 ** Copyright (C) 2005-2017 Mike Pall. See Copyright Notice in luajit.h
   4 */
   5 
   6 #include "lj_obj.h"
   7 
   8 #if LJ_HASFFI
   9 
  10 #include "lj_gc.h"
  11 #include "lj_err.h"
  12 #include "lj_tab.h"
  13 #include "lj_meta.h"
  14 #include "lj_ctype.h"
  15 #include "lj_cconv.h"
  16 #include "lj_cdata.h"
  17 #include "lj_carith.h"
  18 
  19 /* -- C data arithmetic --------------------------------------------------- */
  20 
/* Binary operands of an operator converted to ctypes.
**
** Filled in by carith_checkarg() for operand 0 (L->base) and operand 1
** (L->base+1). ct[i] is the C type used for operand i, or NULL when the
** operand could not be converted. p[i] points to the operand's value; it
** may also be NULL (nil operand) or a non-NULL sentinel (unconvertible
** operand), see carith_checkarg().
*/
typedef struct CDArith {
  uint8_t *p[2];  /* Pointers to the operand values (or NULL/sentinel). */
  CType *ct[2];   /* C types of the operands (NULL if not convertible). */
} CDArith;
  26 
/* Check arguments for arithmetic metamethods.
**
** Converts the two operands at L->base and L->base+1 into (C type, value
** pointer) pairs stored in *ca:
**   cdata   -> its raw type. Pointer values are loaded from the cdata (for
**              references the referenced type is used), functions are
**              treated as a pointer to the function type, and enums are
**              replaced by their child type.
**   integer -> int32_t, number -> double, nil -> void * with a NULL value.
**   string  -> only convertible if the other operand's type is an enum and
**              the string names one of its constant values.
**   other   -> not convertible (ct = NULL, p = non-NULL sentinel).
**
** Returns 1 if both operands were converted, 0 otherwise. On failure *ca is
** still filled in so the caller can build an error message. Raises an
** argument error (lj_err_argt) if fewer than two values are on the stack.
*/
static int carith_checkarg(lua_State *L, CTState *cts, CDArith *ca)
{
  TValue *o = L->base;
  int ok = 1;
  MSize i;
  if (o+1 >= L->top)
    lj_err_argt(L, 1, LUA_TCDATA);
  for (i = 0; i < 2; i++, o++) {
    if (tviscdata(o)) {
      GCcdata *cd = cdataV(o);
      CTypeID id = (CTypeID)cd->ctypeid;
      CType *ct = ctype_raw(cts, id);
      uint8_t *p = (uint8_t *)cdataptr(cd);
      if (ctype_isptr(ct->info)) {
        /* Operate on the pointer value stored in the cdata. */
        p = (uint8_t *)cdata_getptr(p, ct->size);
        if (ctype_isref(ct->info)) ct = ctype_rawchild(cts, ct);
      } else if (ctype_isfunc(ct->info)) {
        /* A function cdata behaves like a pointer to that function type. */
        p = (uint8_t *)*(void **)p;
        ct = ctype_get(cts,
          lj_ctype_intern(cts, CTINFO(CT_PTR, CTALIGN_PTR|id), CTSIZE_PTR));
      }
      /* Enums take part in arithmetic via their child type. */
      if (ctype_isenum(ct->info)) ct = ctype_child(cts, ct);
      ca->ct[i] = ct;
      ca->p[i] = p;
    } else if (tvisint(o)) {
      ca->ct[i] = ctype_get(cts, CTID_INT32);
      ca->p[i] = (uint8_t *)&o->i;
    } else if (tvisnum(o)) {
      ca->ct[i] = ctype_get(cts, CTID_DOUBLE);
      ca->p[i] = (uint8_t *)&o->n;
    } else if (tvisnil(o)) {
      ca->ct[i] = ctype_get(cts, CTID_P_VOID);
      ca->p[i] = (uint8_t *)0;
    } else if (tvisstr(o)) {
      /* The other operand. NOTE(review): assumed to be cdata here, since
      ** cdataV() is used without a tviscdata() check; presumably guaranteed
      ** because a cdata metamethod was invoked -- confirm against callers.
      */
      TValue *o2 = i == 0 ? o+1 : o-1;
      CType *ct = ctype_raw(cts, cdataV(o2)->ctypeid);
      ca->ct[i] = NULL;
      ca->p[i] = (uint8_t *)strVdata(o);
      ok = 0;
      if (ctype_isenum(ct->info)) {
        /* Try to resolve the string as the name of an enum constant. */
        CTSize ofs;
        CType *cct = lj_ctype_getfield(cts, ct, strV(o), &ofs);
        if (cct && ctype_isconstval(cct->info)) {
          ca->ct[i] = ctype_child(cts, cct);
          ca->p[i] = (uint8_t *)&cct->size;  /* Assumes ct does not grow. */
          ok = 1;
        } else {
          /* Unknown constant: overwrite the other slot with the enum type
          ** and stop converting, so the error path can report it.
          */
          ca->ct[1-i] = ct;  /* Use enum to improve error message. */
          ca->p[1-i] = NULL;
          break;
        }
      }
    } else {
      ca->ct[i] = NULL;
      /* Non-NULL sentinel, so it never compares equal to a nil operand. */
      ca->p[i] = (void *)(intptr_t)1;  /* To make it unequal. */
      ok = 0;
    }
  }
  return ok;
}
  88 
  89 /* Pointer arithmetic. */
  90 static int carith_ptr(lua_State *L, CTState *cts, CDArith *ca, MMS mm)
  91 {
  92   CType *ctp = ca->ct[0];
  93   uint8_t *pp = ca->p[0];
  94   ptrdiff_t idx;
  95   CTSize sz;
  96   CTypeID id;
  97   GCcdata *cd;
  98   if (ctype_isptr(ctp->info) || ctype_isrefarray(ctp->info)) {
  99     if ((mm == MM_sub || mm == MM_eq || mm == MM_lt || mm == MM_le) &&
 100         (ctype_isptr(ca->ct[1]->info) || ctype_isrefarray(ca->ct[1]->info))) {
 101       uint8_t *pp2 = ca->p[1];
 102       if (mm == MM_eq) {  /* Pointer equality. Incompatible pointers are ok. */
 103         setboolV(L->top-1, (pp == pp2));
 104         return 1;
 105       }
 106       if (!lj_cconv_compatptr(cts, ctp, ca->ct[1], CCF_IGNQUAL))
 107         return 0;
 108       if (mm == MM_sub) {  /* Pointer difference. */
 109         intptr_t diff;
 110         sz = lj_ctype_size(cts, ctype_cid(ctp->info));  /* Element size. */
 111         if (sz == 0 || sz == CTSIZE_INVALID)
 112           return 0;
 113         diff = ((intptr_t)pp - (intptr_t)pp2) / (int32_t)sz;
 114         /* All valid pointer differences on x64 are in (-2^47, +2^47),
 115         ** which fits into a double without loss of precision.
 116         */
 117         setintptrV(L->top-1, (int32_t)diff);
 118         return 1;
 119       } else if (mm == MM_lt) {  /* Pointer comparison (unsigned). */
 120         setboolV(L->top-1, ((uintptr_t)pp < (uintptr_t)pp2));
 121         return 1;
 122       } else {
 123         lua_assert(mm == MM_le);
 124         setboolV(L->top-1, ((uintptr_t)pp <= (uintptr_t)pp2));
 125         return 1;
 126       }
 127     }
 128     if (!((mm == MM_add || mm == MM_sub) && ctype_isnum(ca->ct[1]->info)))
 129       return 0;
 130     lj_cconv_ct_ct(cts, ctype_get(cts, CTID_INT_PSZ), ca->ct[1],
 131                    (uint8_t *)&idx, ca->p[1], 0);
 132     if (mm == MM_sub) idx = -idx;
 133   } else if (mm == MM_add && ctype_isnum(ctp->info) &&
 134       (ctype_isptr(ca->ct[1]->info) || ctype_isrefarray(ca->ct[1]->info))) {
 135     /* Swap pointer and index. */
 136     ctp = ca->ct[1]; pp = ca->p[1];
 137     lj_cconv_ct_ct(cts, ctype_get(cts, CTID_INT_PSZ), ca->ct[0],
 138                    (uint8_t *)&idx, ca->p[0], 0);
 139   } else {
 140     return 0;
 141   }
 142   sz = lj_ctype_size(cts, ctype_cid(ctp->info));  /* Element size. */
 143   if (sz == CTSIZE_INVALID)
 144     return 0;
 145   pp += idx*(int32_t)sz;  /* Compute pointer + index. */
 146   id = lj_ctype_intern(cts, CTINFO(CT_PTR, CTALIGN_PTR|ctype_cid(ctp->info)),
 147                        CTSIZE_PTR);
 148   cd = lj_cdata_new(cts, id, CTSIZE_PTR);
 149   *(uint8_t **)cdataptr(cd) = pp;
 150   setcdataV(L, L->top-1, cd);
 151   lj_gc_check(L);
 152   return 1;
 153 }
 154 
 155 /* 64 bit integer arithmetic. */
 156 static int carith_int64(lua_State *L, CTState *cts, CDArith *ca, MMS mm)
 157 {
 158   if (ctype_isnum(ca->ct[0]->info) && ca->ct[0]->size <= 8 &&
 159       ctype_isnum(ca->ct[1]->info) && ca->ct[1]->size <= 8) {
 160     CTypeID id = (((ca->ct[0]->info & CTF_UNSIGNED) && ca->ct[0]->size == 8) ||
 161                   ((ca->ct[1]->info & CTF_UNSIGNED) && ca->ct[1]->size == 8)) ?
 162                  CTID_UINT64 : CTID_INT64;
 163     CType *ct = ctype_get(cts, id);
 164     GCcdata *cd;
 165     uint64_t u0, u1, *up;
 166     lj_cconv_ct_ct(cts, ct, ca->ct[0], (uint8_t *)&u0, ca->p[0], 0);
 167     if (mm != MM_unm)
 168       lj_cconv_ct_ct(cts, ct, ca->ct[1], (uint8_t *)&u1, ca->p[1], 0);
 169     switch (mm) {
 170     case MM_eq:
 171       setboolV(L->top-1, (u0 == u1));
 172       return 1;
 173     case MM_lt:
 174       setboolV(L->top-1,
 175                id == CTID_INT64 ? ((int64_t)u0 < (int64_t)u1) : (u0 < u1));
 176       return 1;
 177     case MM_le:
 178       setboolV(L->top-1,
 179                id == CTID_INT64 ? ((int64_t)u0 <= (int64_t)u1) : (u0 <= u1));
 180       return 1;
 181     default: break;
 182     }
 183     cd = lj_cdata_new(cts, id, 8);
 184     up = (uint64_t *)cdataptr(cd);
 185     setcdataV(L, L->top-1, cd);
 186     switch (mm) {
 187     case MM_add: *up = u0 + u1; break;
 188     case MM_sub: *up = u0 - u1; break;
 189     case MM_mul: *up = u0 * u1; break;
 190     case MM_div:
 191       if (id == CTID_INT64)
 192         *up = (uint64_t)lj_carith_divi64((int64_t)u0, (int64_t)u1);
 193       else
 194         *up = lj_carith_divu64(u0, u1);
 195       break;
 196     case MM_mod:
 197       if (id == CTID_INT64)
 198         *up = (uint64_t)lj_carith_modi64((int64_t)u0, (int64_t)u1);
 199       else
 200         *up = lj_carith_modu64(u0, u1);
 201       break;
 202     case MM_pow:
 203       if (id == CTID_INT64)
 204         *up = (uint64_t)lj_carith_powi64((int64_t)u0, (int64_t)u1);
 205       else
 206         *up = lj_carith_powu64(u0, u1);
 207       break;
 208     case MM_unm: *up = (uint64_t)-(int64_t)u0; break;
 209     default: lua_assert(0); break;
 210     }
 211     lj_gc_check(L);
 212     return 1;
 213   }
 214   return 0;
 215 }
 216 
/* Handle ctype arithmetic metamethods.
**
** Fallback for operations not handled by the built-in 64 bit integer or
** pointer arithmetic. Looks up a metamethod for mm on the type of the first
** cdata operand (for pointers: on the child type), then on the second. If
** one is found, it is tail-called via lj_meta_tailcall().
**
** Without a metamethod:
** - __eq compares the value pointers ca->p[0] and ca->p[1] and returns 1.
** - Every other operator raises an error naming both operand types.
*/
static int lj_carith_meta(lua_State *L, CTState *cts, CDArith *ca, MMS mm)
{
  cTValue *tv = NULL;
  if (tviscdata(L->base)) {
    CTypeID id = cdataV(L->base)->ctypeid;
    CType *ct = ctype_raw(cts, id);
    if (ctype_isptr(ct->info)) id = ctype_cid(ct->info);  /* Child type. */
    tv = lj_ctype_meta(cts, id, mm);
  }
  if (!tv && L->base+1 < L->top && tviscdata(L->base+1)) {
    CTypeID id = cdataV(L->base+1)->ctypeid;
    CType *ct = ctype_raw(cts, id);
    if (ctype_isptr(ct->info)) id = ctype_cid(ct->info);  /* Child type. */
    tv = lj_ctype_meta(cts, id, mm);
  }
  if (!tv) {
    const char *repr[2];
    int i, isenum = -1, isstr = -1;
    if (mm == MM_eq) {  /* Equality checks never raise an error. */
      int eq = ca->p[0] == ca->p[1];
      setboolV(L->top-1, eq);
      setboolV(&G(L)->tmptv2, eq);  /* Remember for trace recorder. */
      return 1;
    }
    /* Build a printable type name for each operand for the error message.
    ** Remember which operand (if any) is an enum cdata and which a string.
    */
    for (i = 0; i < 2; i++) {
      if (ca->ct[i] && tviscdata(L->base+i)) {
        if (ctype_isenum(ca->ct[i]->info)) isenum = i;
        repr[i] = strdata(lj_ctype_repr(L, ctype_typeid(cts, ca->ct[i]), NULL));
      } else {
        if (tvisstr(&L->base[i])) isstr = i;
        repr[i] = lj_typename(&L->base[i]);
      }
    }
    /* isenum/isstr are in {-1,0,1}; XOR == 1 only for {0,1}, i.e. one
    ** operand is an enum and the other a string: report the failed
    ** string-to-enum conversion instead of a generic error.
    */
    if ((isenum ^ isstr) == 1)
      lj_err_callerv(L, LJ_ERR_FFI_BADCONV, repr[isstr], repr[isenum]);
    /* NOTE(review): `mm < MM_add` assumes the comparison metamethods precede
    ** MM_add in the MMS enum -- confirm in lj_obj.h.
    */
    lj_err_callerv(L, mm == MM_len ? LJ_ERR_FFI_BADLEN :
                      mm == MM_concat ? LJ_ERR_FFI_BADCONCAT :
                      mm < MM_add ? LJ_ERR_FFI_BADCOMP : LJ_ERR_FFI_BADARITH,
                   repr[0], repr[1]);
  }
  return lj_meta_tailcall(L, tv);
}
 260 
 261 /* Arithmetic operators for cdata. */
 262 int lj_carith_op(lua_State *L, MMS mm)
 263 {
 264   CTState *cts = ctype_cts(L);
 265   CDArith ca;
 266   if (carith_checkarg(L, cts, &ca)) {
 267     if (carith_int64(L, cts, &ca, mm) || carith_ptr(L, cts, &ca, mm)) {
 268       copyTV(L, &G(L)->tmptv2, L->top-1);  /* Remember for trace recorder. */
 269       return 1;
 270     }
 271   }
 272   return lj_carith_meta(L, cts, &ca, mm);
 273 }
 274 
 275 /* -- 64 bit integer arithmetic helpers ----------------------------------- */
 276 
 277 #if LJ_32 && LJ_HASJIT
 278 /* Signed/unsigned 64 bit multiplication. */
 279 int64_t lj_carith_mul64(int64_t a, int64_t b)
 280 {
 281   return a * b;
 282 }
 283 #endif
 284 
 285 /* Unsigned 64 bit division. */
 286 uint64_t lj_carith_divu64(uint64_t a, uint64_t b)
 287 {
 288   if (b == 0) return U64x(80000000,00000000);
 289   return a / b;
 290 }
 291 
 292 /* Signed 64 bit division. */
 293 int64_t lj_carith_divi64(int64_t a, int64_t b)
 294 {
 295   if (b == 0 || (a == (int64_t)U64x(80000000,00000000) && b == -1))
 296     return U64x(80000000,00000000);
 297   return a / b;
 298 }
 299 
 300 /* Unsigned 64 bit modulo. */
 301 uint64_t lj_carith_modu64(uint64_t a, uint64_t b)
 302 {
 303   if (b == 0) return U64x(80000000,00000000);
 304   return a % b;
 305 }
 306 
 307 /* Signed 64 bit modulo. */
 308 int64_t lj_carith_modi64(int64_t a, int64_t b)
 309 {
 310   if (b == 0) return U64x(80000000,00000000);
 311   if (a == (int64_t)U64x(80000000,00000000) && b == -1) return 0;
 312   return a % b;
 313 }
 314 
 315 /* Unsigned 64 bit x^k. */
 316 uint64_t lj_carith_powu64(uint64_t x, uint64_t k)
 317 {
 318   uint64_t y;
 319   if (k == 0)
 320     return 1;
 321   for (; (k & 1) == 0; k >>= 1) x *= x;
 322   y = x;
 323   if ((k >>= 1) != 0) {
 324     for (;;) {
 325       x *= x;
 326       if (k == 1) break;
 327       if (k & 1) y *= x;
 328       k >>= 1;
 329     }
 330     y *= x;
 331   }
 332   return y;
 333 }
 334 
 335 /* Signed 64 bit x^k. */
 336 int64_t lj_carith_powi64(int64_t x, int64_t k)
 337 {
 338   if (k == 0)
 339     return 1;
 340   if (k < 0) {
 341     if (x == 0)
 342       return U64x(7fffffff,ffffffff);
 343     else if (x == 1)
 344       return 1;
 345     else if (x == -1)
 346       return (k & 1) ? -1 : 1;
 347     else
 348       return 0;
 349   }
 350   return (int64_t)lj_carith_powu64((uint64_t)x, (uint64_t)k);
 351 }
 352 
 353 #endif

/* [<][>][^][v][top][bottom][index][help] */