/**
Internal module for pushing and getting _functions and delegates.

LuaD allows for pushing of all D function or delegate types with return type and parameter types compatible with LuaD (see $(DPMODULE stack)).

For a fixed number of multiple return values, return a $(STDREF typecons,Tuple) or a static array. For a variable number of return values, return $(MREF LuaVariableReturn).

As a special case for $(D const(char)[]) parameter types in _functions pushed to Lua, no copy of the string is made when called; take care not to let such references escape, as they are effectively $(D scope) parameters.
When a copy is desired, use $(D char[]) or $(D string), or $(D dup) or $(D idup) the string manually.

If a function with the $(D lua_CFunction) signature is encountered, it is pushed directly with no inserted conversions or overhead.

Typesafe varargs are supported when pushing _functions to Lua, but as of DMD 2.054, compiler bugs prevent getting delegates with varargs from Lua (use $(DPREF lfunction,LuaFunction) instead).
*/
module luad.conversions.functions;

import core.memory;
import std.range;
import std.string : toStringz;
import std.traits;
import std.typetuple;

import luad.c.all;

import luad.stack;

// Raises a Lua error reporting that the running function was called with
// `nargs` arguments while `expected` were required.
// luaL_error raises a Lua error, so this function does not return normally.
private void argsError(lua_State* L, int nargs, int expected)
{
	lua_Debug debugInfo; // D default-initializes its pointer fields to null

	const(char)* namewhat = null;
	const(char)* name = null;

	// Level 0 is the function currently running. The "n" option fills in
	// name/namewhat, but Lua may be unable to name the callee (for example a
	// function value called without a name).
	if(lua_getstack(L, 0, &debugInfo) && lua_getinfo(L, "n", &debugInfo))
	{
		namewhat = debugInfo.namewhat;
		name = debugInfo.name;
	}

	// Mirror lauxlib's luaL_argerror, which substitutes "?" for an unknown
	// name, instead of letting a null pointer reach the %s conversion.
	if(name is null)
		name = "?";
	if(namewhat is null || *namewhat == '\0')
		namewhat = "function";

	luaL_error(L, "call to %s '%s': got %d arguments, expected %d",
		namewhat, name, nargs, expected);
}

// StripHeadQual!T removes the outermost (head) const/immutable qualifier from
// pointer and slice types while keeping the qualifier on the pointee/element,
// e.g. const(char[]) becomes const(char)[]. Any other type, including mutable
// pointers, slices and static arrays, is returned unchanged.
template StripHeadQual(T : const(T*))
{
	alias const(T)* StripHeadQual;
}

template StripHeadQual(T : const(T[]))
{
	alias const(T)[] StripHeadQual;
}

template StripHeadQual(T : immutable(T*))
{
	alias immutable(T)* StripHeadQual;
}

template StripHeadQual(T : immutable(T[]))
{
	alias immutable(T)[] StripHeadQual;
}

template StripHeadQual(T : T[])
{
	alias T[] StripHeadQual;
}

template StripHeadQual(T : T*)
{
	alias T* StripHeadQual;
}

template StripHeadQual(T : T[N], size_t N)
{
	alias T[N] StripHeadQual;
}

template StripHeadQual(T)
{
	alias T StripHeadQual;
}

// Parameter types of the callable T with head qualifiers stripped, so that
// local variables of these types can be assigned while collecting arguments.
template FillableParameterTypeTuple(T)
{
	alias staticMap!(StripHeadQual, ParameterTypeTuple!T) FillableParameterTypeTuple;
}

// Return type of the callable T with its head qualifier stripped.
template BindableReturnType(T)
{
	alias StripHeadQual!(ReturnType!T) BindableReturnType;
}

//Call with or without return value, propagating Exceptions as Lua errors.
//This should rather be throwing a userdata with __tostring and a reference to
//the thrown exception; as it is now, everything but the exception's string
//representation is lost.
/*
 * Invokes `func(args)` and pushes its result onto the Lua stack.
 *
 * This overload serves void and bindable (non-const, non-immutable) return
 * types, whose result can be default-declared first and assigned inside the
 * try block. A thrown Exception is reported through luaL_error with the text
 * of Exception.toString(); luaL_error raises a Lua error and does not return.
 *
 * Returns: 0 for a void function, otherwise the count returned by
 * pushReturnValues for the result.
 */
int callFunction(T)(lua_State* L, T func, ParameterTypeTuple!T args)
	if(!is(BindableReturnType!T == const) &&
		!is(BindableReturnType!T == immutable))
{
	alias BindableReturnType!T R;

	static if(is(R == void))
	{
		try
			func(args);
		catch(Exception ex)
			luaL_error(L, "%s", toStringz(ex.toString()));

		return 0;
	}
	else
	{
		R result;

		try
			result = func(args);
		catch(Exception ex)
			luaL_error(L, "%s", toStringz(ex.toString()));

		return pushReturnValues(L, result);
	}
}

/*
 * Same contract as above, for const or immutable return types. Such a result
 * cannot be declared first and assigned later, so the guarded call lives in a
 * nested function whose value is handed straight to pushReturnValues.
 */
int callFunction(T)(lua_State* L, T func, ParameterTypeTuple!T args)
	if(is(BindableReturnType!T == const) ||
		is(BindableReturnType!T == immutable))
{
	auto ref guardedCall()
	{
		try
		{
			return func(args);
		}
		catch(Exception ex)
		{
			// luaL_error raises a Lua error; control does not come back here.
			luaL_error(L, "%s", toStringz(ex.toString()));
		}
	}

	return pushReturnValues(L, guardedCall());
}

private:

// TODO: right now, virtual functions on specialized classes can be called with base classes as 'self', not safe!
// Lua C function that calls the member function of type T on an instance of
// Class. Lua argument 1 must be the object (checked with luaL_checkudata
// against the metatable named Class.mangleof). The remaining Lua arguments
// become the call's parameters.
// Upvalue 1 holds either a static wrapper taking the object as its first
// parameter (virtual == true) or the method's code pointer (virtual == false).
extern(C) int methodWrapper(T, Class, bool virtual)(lua_State* L)
{
	alias ParameterTypeTuple!T Args;

	// A trailing typesafe variadic parameter may receive zero Lua values, so it
	// does not count towards the minimum number of arguments.
	// NOTE(review): relies on getArgument (luad.stack) yielding an empty slice
	// when no values remain for the variadic parameter - confirm.
	static if(variadicFunctionStyle!T == Variadic.typesafe)
		enum int requiredArgs = cast(int)Args.length - 1;
	else
		enum int requiredArgs = cast(int)Args.length;

	//Check arguments (the extra one is 'self')
	int top = lua_gettop(L);
	if(top < requiredArgs + 1)
		argsError(L, top, requiredArgs + 1);

	Class self = *cast(Class*)luaL_checkudata(L, 1, toStringz(Class.mangleof));

	static if(virtual)
	{
		alias ReturnType!T function(Class, Args) VirtualWrapper;
		VirtualWrapper func = cast(VirtualWrapper)lua_touserdata(L, lua_upvalueindex(1));
	}
	else
	{
		// Rebuild the delegate from the object and the stored code pointer.
		T func;
		func.ptr = cast(void*)self;
		func.funcptr = cast(typeof(func.funcptr))lua_touserdata(L, lua_upvalueindex(1));
	}

	//Assemble arguments. The locals are assigned one by one below, so head
	//qualifiers are stripped (as functionWrapper does); otherwise parameters
	//such as const(char[]) or `in char[]` could not be filled in.
	static if(virtual)
	{
		TypeTuple!(Class, FillableParameterTypeTuple!T) allArgs;
		allArgs[0] = self;
		alias allArgs[1..$] args;
	}
	else
	{
		FillableParameterTypeTuple!T allArgs;
		alias allArgs args;
	}

	foreach(i, Arg; Args)
		args[i] = getArgument!(T, i)(L, i + 2);

	return callFunction!(typeof(func))(L, func, allArgs);
}

// Lua C function that calls the D function or delegate of type T stored in
// upvalue 1 by pushFunction. It converts Lua arguments 1..n to the callee's
// parameters and pushes its return value(s).
extern(C) int functionWrapper(T)(lua_State* L)
{
	alias FillableParameterTypeTuple!T Args;

	// A trailing typesafe variadic parameter may receive zero Lua values, so it
	// does not count towards the minimum number of arguments.
	static if(variadicFunctionStyle!T == Variadic.typesafe)
		enum int requiredArgs = cast(int)Args.length - 1;
	else
		enum int requiredArgs = cast(int)Args.length;

	//Check arguments
	int top = lua_gettop(L);
	if(top < requiredArgs)
		argsError(L, top, requiredArgs);

	//Get function
	// NOTE(review): is(T == function) is false for function-pointer types, so
	// function pointers appear to be stored in a full userdata as well (see
	// pushFunction); confirm whether a light userdata was intended.
	static if(is(T == function))
		T func = cast(T)lua_touserdata(L, lua_upvalueindex(1));
	else
		T func = *cast(T*)lua_touserdata(L, lua_upvalueindex(1));

	//Assemble arguments
	Args args;
	foreach(i, Arg; Args)
		args[i] = getArgument!(T, i)(L, i + 1);

	return callFunction!T(L, func, args);
}

// __gc metamethod of the "__dcall" metatable: unregisters the memory range
// that pushFunction registered with the D GC for this userdata.
extern(C) int functionCleaner(lua_State* L)
{
	GC.removeRange(lua_touserdata(L, 1));
	return 0;
}

public:

// Pushes a Lua closure (functionWrapper!T) that calls `func`. Upvalue 1 is the
// callable: a light userdata in the is(T == function) branch, otherwise a full
// userdata holding a copy of `func` with the "__dcall" metatable.
void pushFunction(T)(lua_State* L, T func) if (isSomeFunction!T)
{
	static if(is(T == function))
		lua_pushlightuserdata(L, func);
	else
	{
		T* udata = cast(T*)lua_newuserdata(L, T.sizeof);
		*udata = func;

		if(luaL_newmetatable(L, "__dcall") == 1)
		{
			lua_pushcfunction(L, &functionCleaner);
			lua_setfield(L, -2, "__gc");
		}

		lua_setmetatable(L, -2);

		// The userdata is Lua-owned memory that the D GC neither manages nor
		// scans. GC.addRoot would only pin a GC block *at* that address, leaving
		// the delegate's context pointer unprotected. Registering the bytes as a
		// scanned range keeps the context alive until functionCleaner runs.
		GC.addRange(udata, T.sizeof);
	}

	lua_pushcclosure(L, &functionWrapper!T, 1);
}

// Pushes a Lua closure that calls method `member` on a Class instance passed
// as the first Lua argument.
// TODO: optimize for non-virtual functions
void pushMethod(Class, string member)(lua_State* L) if (isSomeFunction!(__traits(getMember, Class, member)))
{
	alias typeof(mixin("&Class.init." ~ member)) T;

	// Delay vtable lookup until the right time
	static ReturnType!T virtualWrapper(Class self, ParameterTypeTuple!T args)
	{
		return mixin("self." ~ member)(args);
	}

	lua_pushlightuserdata(L, &virtualWrapper);
	lua_pushcclosure(L, &methodWrapper!(T, Class, true), 1);
}

/**
 * Returns a delegate of type T that calls the Lua value at stack index idx.
 * Each call pushes the referenced value and the arguments onto the stack of L,
 * then hands over to callWithRet.
 *
 * Currently this function allocates a reference in the registry that is never deleted,
 * one for each call... see code comments
 */
T getFunction(T)(lua_State* L, int idx) if (is(T == delegate))
{
	auto func = new class
	{
		int lref;
		this()
		{
			lua_pushvalue(L, idx);
			lref = luaL_ref(L, LUA_REGISTRYINDEX);
		}

		//Alright... how to fix this?
		//The problem is that this object tends to be finalized after L is freed (by LuaState's destructor or otherwise).
		//If you have a good solution to the problem of dangling references to a lua_State,
		//please contact me :)

		/+~this()
		{
			luaL_unref(L, LUA_REGISTRYINDEX, lref);
		}+/

		void push()
		{
			lua_rawgeti(L, LUA_REGISTRYINDEX, lref);
		}
	};

	alias ReturnType!T RetType;
	alias ParameterTypeTuple!T Args;

	return delegate RetType(Args args)
	{
		func.push();
		foreach(arg; args)
			pushValue(L, arg);

		return callWithRet!RetType(L, args.length);
	};
}

/**
 * Type for efficiently returning a variable number of return values
 * from a function.
 *
 * Use $(D variableReturn) to instantiate it.
* Params:
*     Range = any input range
*/
struct LuaVariableReturn(Range) if(isInputRange!Range)
{
	alias WrappedType = Range; /// The type of the wrapped input range.
	Range returnValues; /// The wrapped input range.
}

/**
 * Create a LuaVariableReturn object for efficiently returning
 * a variable number of values from a function.
 * Params:
 *     returnValues = any input range
 * Example:
-----------------------------
LuaVariableReturn!(uint[]) makeList(uint n)
{
	uint[] list;

	foreach(i; 1 .. n + 1)
	{
		list ~= i;
	}

	return variableReturn(list);
}

lua["makeList"] = &makeList;

lua.doString(`
	local one, two, three, four = makeList(4)
	assert(one == 1)
	assert(two == 2)
	assert(three == 3)
	assert(four == 4)
`);
-----------------------------
 */
LuaVariableReturn!Range variableReturn(Range)(Range returnValues)
	if(isInputRange!Range)
{
	return typeof(return)(returnValues);
}

version(unittest)
{
	import luad.testing;
	import std.typecons;

	// Lua state shared by the unittest blocks below: created by the first
	// block and closed by the last one.
	private lua_State* L;
}

unittest
{
	L = luaL_newstate();
	luaL_openlibs(L);

	//functions
	static const(char)[] func(const(char)[] s)
	{
		return "Hello, " ~ s;
	}

	pushValue(L, &func);
	assert(lua_isfunction(L, -1));
	lua_setglobal(L, "sayHello");

	unittest_lua(L, `
		local ret = sayHello("foo")
		local expect = "Hello, foo"
		assert(ret == expect,
			("sayHello return type - got '%s', expected '%s'"):format(ret, expect)
		)
	`);

	static uint countSpaces(const(char)[] s)
	{
		uint n = 0;
		foreach(dchar c; s)
			if(c == ' ')
				++n;

		return n;
	}

	pushValue(L, &countSpaces);
	assert(lua_isfunction(L, -1));
	lua_setglobal(L, "countSpaces");

	unittest_lua(L, `
		assert(countSpaces("Hello there, world!") == 2)
	`);

	//delegates
	double curry = 3.14 * 2;
	double closure(double x)
	{
		return curry * x;
	}

	pushValue(L, &closure);
	assert(lua_isfunction(L, -1));
	lua_setglobal(L, "circle");

	unittest_lua(L, `
		assert(circle(2) == 3.14 * 4, "closure return type mismatch!")
	`);

	// Const parameters
	static bool isEmpty(const(char[]) str) { return str.length == 0; }
	static bool isEmpty2(in char[] str) { return str.length == 0; }

	pushValue(L, &isEmpty);
	lua_setglobal(L, "isEmpty");

	pushValue(L, &isEmpty2);
	lua_setglobal(L, "isEmpty2");

	unittest_lua(L, `
		assert(isEmpty(""))
		assert(isEmpty2(""))
		assert(not isEmpty("a"))
		assert(not isEmpty2("a"))
	`);

	// Immutable parameters
	static immutable(char[]) returnArg(immutable(char[]) str) { return str; }

	pushValue(L, &returnArg);
	lua_setglobal(L, "returnArg");

	unittest_lua(L, `assert(returnArg("foo") == "foo")`);
}

version(unittest) import luad.base;

// multiple return values
unittest
{
	// tuple returns
	auto nameInfo = ["foo"];
	auto ageInfo = [42];

	alias Tuple!(string, "name", uint, "age") GetInfoResult;
	GetInfoResult getInfo(int idx)
	{
		GetInfoResult result;
		result.name = nameInfo[idx];
		result.age = ageInfo[idx];
		return result;
	}

	pushValue(L, &getInfo);
	lua_setglobal(L, "getInfo");

	unittest_lua(L, `
		local name, age = getInfo(0)
		assert(name == "foo")
		assert(age == 42)
	`);

	// static array returns
	static string[2] getName()
	{
		string[2] ret;
		ret[0] = "Foo";
		ret[1] = "Bar";
		return ret;
	}

	pushValue(L, &getName);
	lua_setglobal(L, "getName");

	unittest_lua(L, `
		local first, last = getName()
		assert(first == "Foo")
		assert(last == "Bar")
	`);

	// variable length returns
	LuaVariableReturn!(uint[]) makeList(uint n)
	{
		uint[] list;

		foreach(i; 1 .. n + 1)
		{
			list ~= i;
		}

		return variableReturn(list);
	}

	auto makeList2(uint n)
	{
		return variableReturn(iota(1, n + 1));
	}

	pushValue(L, &makeList);
	lua_setglobal(L, "makeList");
	pushValue(L, &makeList2);
	lua_setglobal(L, "makeList2");

	unittest_lua(L, `
		for i, f in pairs{makeList, makeList2} do
			local one, two, three, four = f(4)
			assert(one == 1)
			assert(two == 2)
			assert(three == 3)
			assert(four == 4)
		end
	`);
}

// D-style typesafe varargs
unittest
{
	static string concat(const(char)[][] pieces...)
	{
		string result;
		foreach(piece; pieces)
			result ~= piece;
		return result;
	}

	pushValue(L, &concat);
	lua_setglobal(L, "concat");

	unittest_lua(L, `
		local whole = concat("he", "llo", ", ", "world!")
		assert(whole == "hello, world!")
	`);

	static const(char)[] concat2(char separator, const(char)[][] pieces...)
	{
		if(pieces.length == 0)
			return "";

		string result;
		foreach(piece; pieces[0..$-1])
			result ~= piece ~ separator;

		return result ~ pieces[$-1];
	}

	pushValue(L, &concat2);
	lua_setglobal(L, "concat2");

	unittest_lua(L, `
		local whole = concat2(",", "one", "two", "three", "four")
		assert(whole == "one,two,three,four")
	`);
}

// get delegates from Lua
unittest
{
	lua_getglobal(L, "string");
	lua_getfield(L, -1, "match");
	auto match = popValue!(string delegate(string, string))(L);
	lua_pop(L, 1);

	auto result = match("foobar@example.com", "([^@]+)@example.com");
	assert(result == "foobar");

	// multiple return values
	// The status is kept in a local (not inside assert) so the chunk still
	// runs when asserts are compiled out.
	auto status = luaL_dostring(L, `function multRet(a) return "foo", a end`);
	assert(status == 0, "failed to run the chunk defining multRet");
	lua_getglobal(L, "multRet");
	auto multRet = popValue!(Tuple!(string, int) delegate(int))(L);

	auto results = multRet(42);
	assert(results[0] == "foo");
	assert(results[1] == 42);
}

// Nested call stack testing
unittest
{
	alias string delegate(string) MyFun;

	MyFun[string] funcs;

	pushValue(L, (string name, MyFun fun) {
		funcs[name] = fun;
	});
	lua_setglobal(L, "addFun");

	pushValue(L, (string name, string arg) {
		auto top = lua_gettop(L);
		auto result = funcs[name](arg);
		assert(lua_gettop(L) == top);
		return result;
	});
	lua_setglobal(L, "callFun");

	auto top = lua_gettop(L);

	// Check the chunk's status directly so a Lua-side failure is reported as
	// such rather than only as a stack-height mismatch below.
	auto status = luaL_dostring(L, q{
		addFun("echo", function(s) return s end)
		local result = callFun("echo", "test")
		assert(result == "test")
	});
	assert(status == 0, "nested call stack chunk failed");

	assert(lua_gettop(L) == top);
}

unittest
{
	assert(lua_gettop(L) == 0);
	lua_close(L);
}