1 /**
2 Internal module for pushing and getting _functions and delegates.
3 
4 LuaD allows for pushing of all D function or delegate types with return type and parameter types compatible with LuaD (see $(DPMODULE stack)).
5 
6 For a fixed number of multiple return values, return a $(STDREF typecons,Tuple) or a static array. For a variable number of return values, return $(MREF LuaVariableReturn).
7 
As a special case for $(D const(char)[]) parameter types in _functions pushed to Lua, no copy of the string is made when the function is called; take care not to escape such references, as they are effectively $(D scope) parameters.
9 When a copy is desired, use $(D char[]) or $(D string), or $(D dup) or $(D idup) the string manually.
10 
11 If a function with the $(D lua_CFunction) signature is encountered, it is pushed directly with no inserted conversions or overhead.
12 
13 Typesafe varargs is supported when pushing _functions to Lua, but as of DMD 2.054, compiler bugs prevent getting delegates with varargs from Lua (use $(DPREF lfunction,LuaFunction) instead).
14 */
15 module luad.conversions.functions;
16 
17 import core.memory;
18 import std.range;
19 import std.string : toStringz;
20 import std.traits;
21 import std.typetuple;
22 
23 import luad.c.all;
24 
25 import luad.stack;
26 
/*
 * Raises a Lua error reporting that the currently running C function received
 * too few arguments.
 *
 * Params:
 *   L        = the Lua state; stack level 0 is the function being called
 *   nargs    = number of arguments actually received
 *   expected = number of arguments the caller requires at minimum
 *
 * luaL_error does not return to this function.
 */
private void argsError(lua_State* L, int nargs, int expected)
{
	lua_Debug debugInfo;

	// No active stack frame to describe: report only the counts
	// (the same fallback luaL_argerror uses).
	if(!lua_getstack(L, 0, &debugInfo))
	{
		luaL_error(L, "got %d arguments, expected %d", nargs, expected);
		return;
	}

	lua_getinfo(L, "n", &debugInfo);

	// lua_getinfo leaves the name null when Lua cannot determine one;
	// substitute "?" as luaL_argerror does rather than passing null to %s.
	const(char)* name = debugInfo.name ? debugInfo.name : "?";

	luaL_error(L, "call to %s '%s': got %d arguments, expected %d",
		debugInfo.namewhat, name, nargs, expected);
}
35 
/*
 * StripHeadQual!T removes the head (outermost) const/immutable qualifier from
 * pointer and dynamic-array types while keeping the qualifier on the pointee or
 * element type, e.g. const(char[]) -> const(char)[] and
 * immutable(int*) -> immutable(int)*. A variable of the resulting type can be
 * assigned after declaration, and still converts implicitly to the original
 * type. Used by FillableParameterTypeTuple and BindableReturnType.
 */

// const(U*) -> const(U)*
template StripHeadQual(T : const(T*))
{
	alias const(T)* StripHeadQual;
}

// const(U[]) -> const(U)[]
template StripHeadQual(T : const(T[]))
{
	alias const(T)[] StripHeadQual;
}

// immutable(U*) -> immutable(U)*
template StripHeadQual(T : immutable(T*))
{
	alias immutable(T)* StripHeadQual;
}

// immutable(U[]) -> immutable(U)[]
template StripHeadQual(T : immutable(T[]))
{
	alias immutable(T)[] StripHeadQual;
}

// Dynamic arrays without a head qualifier: rebuilt unchanged.
template StripHeadQual(T : T[])
{
	alias T[] StripHeadQual;
}

// Pointers without a head qualifier: rebuilt unchanged.
template StripHeadQual(T : T*)
{
	alias T* StripHeadQual;
}

// Static arrays: rebuilt as T[N] from the matched element type and length.
template StripHeadQual(T : T[N], size_t N)
{
	alias T[N] StripHeadQual;
}

// Any other type is returned unchanged (qualifiers included).
template StripHeadQual(T)
{
	alias T StripHeadQual;
}
75 
// The parameter types of callable T, each with its head qualifier removed by
// StripHeadQual, so that a local of every type can be declared first and
// assigned afterwards (as the Lua call wrappers do when collecting arguments).
template FillableParameterTypeTuple(T)
{
	alias FillableParameterTypeTuple = staticMap!(StripHeadQual, ParameterTypeTuple!T);
}
80 
// The return type of callable T with its head qualifier removed by StripHeadQual.
template BindableReturnType(T)
{
	alias BindableReturnType = StripHeadQual!(ReturnType!T);
}
85 
// Invokes `func` with `args` and pushes its result(s) onto the Lua stack.
//
// This overload handles functions whose bindable return type is void or not
// head-const/immutable, so the result can be declared first and assigned inside
// the try block. A D Exception thrown by `func` is turned into a Lua error whose
// message is the exception's toString() text; the exception object itself is
// not handed to Lua.
//
// Returns: the number of results for Lua: 0 for void functions, otherwise the
// count returned by pushReturnValues.
int callFunction(T)(lua_State* L, T func, ParameterTypeTuple!T args)
	if(!(is(BindableReturnType!T == const) ||
	     is(BindableReturnType!T == immutable)))
{
	alias BindableReturnType!T Result;

	static if(is(Result == void))
	{
		try
			func(args);
		catch(Exception e)
			luaL_error(L, "%s", toStringz(e.toString()));

		return 0;
	}
	else
	{
		Result result;

		try
			result = func(args);
		catch(Exception e)
			luaL_error(L, "%s", toStringz(e.toString()));

		return pushReturnValues(L, result);
	}
}
116 
// Counterpart of the callFunction overload for mutable/void return types, used
// when the bindable return type is const or immutable. Such a result cannot be
// declared first and assigned later, so the call and its try/catch live in a
// nested function whose result is passed straight to pushReturnValues.
// A D Exception becomes a Lua error carrying the exception's toString() text.
int callFunction(T)(lua_State* L, T func, ParameterTypeTuple!T args)
	if(is(BindableReturnType!T == const) ||
	   is(BindableReturnType!T == immutable))
{
	auto ref call()
	{
		try
			return func(args);
		catch(Exception e)
		{
			luaL_error(L, "%s", e.toString().toStringz());
			// luaL_error raises a Lua error and does not return here. Mark the
			// end of this non-void function as unreachable explicitly instead of
			// letting control appear to fall off it, which the compiler diagnoses.
			assert(0);
		}
	}

	return pushReturnValues(L, call());
}
133 
private:

// TODO: right now, virtual functions on specialized classes can be called with base classes as 'self', not safe!
/*
 * Lua C function that forwards a call from Lua to a method of Class.
 *
 * Lua argument 1 must be a Class userdata (checked with luaL_checkudata against
 * the metatable named Class.mangleof); Lua arguments 2 .. n are converted with
 * getArgument and passed as the method's parameters.
 *
 * Template parameters:
 *   T       = the method's delegate type
 *   Class   = the class the method belongs to
 *   virtual = true: upvalue 1 is a function pointer taking `self` as an explicit
 *             first parameter; false: upvalue 1 is the method's function
 *             pointer, combined with `self` into a delegate of type T.
 */
extern(C) int methodWrapper(T, Class, bool virtual)(lua_State* L)
{
	// Head-unqualified parameter types so each argument local can be assigned,
	// as functionWrapper does; e.g. a const(char[]) parameter is stored as
	// const(char)[] and converts implicitly when the method is called.
	alias FillableParameterTypeTuple!T Args;

	// Lua argument 1 is `self`. A trailing typesafe-variadic parameter may
	// receive zero values, so it does not count towards the required arguments.
	static if(variadicFunctionStyle!T == Variadic.typesafe)
		enum int requiredArgs = Args.length;
	else
		enum int requiredArgs = Args.length + 1;

	//Check arguments
	int top = lua_gettop(L);
	if(top < requiredArgs)
		argsError(L, top, requiredArgs);

	Class self = *cast(Class*)luaL_checkudata(L, 1, toStringz(Class.mangleof));

	static if(virtual)
	{
		// Declared with the method's original parameter types so that it matches
		// the type of the function pointer stored as the upvalue.
		alias ReturnType!T function(Class, ParameterTypeTuple!T) VirtualWrapper;
		VirtualWrapper func = cast(VirtualWrapper)lua_touserdata(L, lua_upvalueindex(1));
	}
	else
	{
		T func;
		func.ptr = cast(void*)self;
		func.funcptr = cast(typeof(func.funcptr))lua_touserdata(L, lua_upvalueindex(1));
	}

	//Assemble arguments; Lua argument i + 2 maps to method parameter i.
	Args args;
	foreach(i, Arg; Args)
		args[i] = getArgument!(T, i)(L, i + 2);

	static if(virtual)
		return callFunction!VirtualWrapper(L, func, self, args);
	else
		return callFunction!T(L, func, args);
}
178 
/*
 * Lua C function that forwards a call from Lua to the D function or delegate of
 * type T stored in upvalue 1 of the running closure.
 *
 * Lua arguments 1 .. n are converted with getArgument and passed to the callable;
 * the value returned by callFunction (the number of Lua results) is returned.
 */
extern(C) int functionWrapper(T)(lua_State* L)
{
	// Head-unqualified parameter types so each argument local can be assigned.
	alias FillableParameterTypeTuple!T Args;

	// A trailing typesafe-variadic parameter (e.g. `const(char)[][] pieces...`)
	// may legitimately receive zero values, so it is not required.
	// NOTE(review): assumes getArgument builds the variadic array from however
	// many Lua values remain, possibly none — confirm in luad.stack.
	static if(variadicFunctionStyle!T == Variadic.typesafe)
		enum int requiredArgs = Args.length - 1;
	else
		enum int requiredArgs = Args.length;

	//Check arguments
	int top = lua_gettop(L);
	if(top < requiredArgs)
		argsError(L, top, requiredArgs);

	//Get function
	static if(is(T == function))
		T func = cast(T)lua_touserdata(L, lua_upvalueindex(1));
	else
		T func = *cast(T*)lua_touserdata(L, lua_upvalueindex(1));

	//Assemble arguments; Lua argument i + 1 maps to parameter i.
	Args args;
	foreach(i, Arg; Args)
		args[i] = getArgument!(T, i)(L, i + 1);

	return callFunction!T(L, func, args);
}
201 
// __gc metamethod of the "__dcall" metatable, run by Lua when a userdata created
// by pushFunction is collected. Unregisters the userdata's bytes from the D GC's
// scan ranges so the GC stops scanning memory Lua is about to free.
extern(C) int functionCleaner(lua_State* L)
{
	GC.removeRange(lua_touserdata(L, 1));
	return 0;
}

public:

/*
 * Pushes a D function or delegate onto the Lua stack as a C closure over
 * functionWrapper!T, with the callable stored in upvalue 1.
 *
 * When is(T == function) holds, the callable is stored as light userdata.
 * Otherwise it is copied into a full userdata whose "__dcall" metatable has
 * functionCleaner as __gc.
 * NOTE(review): is(T == function) is false for function *pointer* types, so
 * function pointers presumably also take the userdata path (functionWrapper
 * reads them back the same way) — confirm whether isFunctionPointer!T was meant.
 */
void pushFunction(T)(lua_State* L, T func) if (isSomeFunction!T)
{
	static if(is(T == function))
		lua_pushlightuserdata(L, func);
	else
	{
		T* udata = cast(T*)lua_newuserdata(L, T.sizeof);
		*udata = func;

		// udata is allocated by Lua, not by the D GC. GC.addRoot(udata) would only
		// keep alive a GC block that udata points *to*; it would not scan udata's
		// contents. Registering the bytes as a scan range lets the GC see the
		// delegate's context pointer, keeping its closure/object alive for as long
		// as Lua holds the userdata (functionCleaner removes the range).
		GC.addRange(udata, T.sizeof);

		if(luaL_newmetatable(L, "__dcall") == 1)
		{
			lua_pushcfunction(L, &functionCleaner);
			lua_setfield(L, -2, "__gc");
		}

		lua_setmetatable(L, -2);
	}

	lua_pushcclosure(L, &functionWrapper!T, 1);
}
232 
// TODO: optimize for non-virtual functions
/*
 * Pushes a Lua C closure that calls method `member` of Class.
 *
 * The closure expects the Class object as its first Lua argument. Its upvalue is
 * a light userdata holding the address of a static helper that performs the call
 * as `self.member(args)`, so the method is looked up on the object at call time.
 */
void pushMethod(Class, string member)(lua_State* L) if (isSomeFunction!(__traits(getMember, Class, member)))
{
	alias MethodType = typeof(mixin("&Class.init." ~ member));

	// Dispatches through `self`, delaying any vtable lookup until the call.
	static ReturnType!MethodType dispatch(Class self, ParameterTypeTuple!MethodType args)
	{
		mixin("return self." ~ member ~ "(args);");
	}

	lua_pushlightuserdata(L, &dispatch);
	lua_pushcclosure(L, &methodWrapper!(MethodType, Class, true), 1);
}
247 
/**
 * Wraps the Lua function at stack index idx in a D delegate of type T.
 *
 * Calling the delegate pushes the Lua function followed by the D arguments onto
 * the stack of L and performs the call through callWithRet.
 *
 * Currently this function allocates a reference in the registry that is never deleted,
 * one for each call... see code comments
 */
T getFunction(T)(lua_State* L, int idx) if (is(T == delegate))
{
	alias ReturnType!T Result;
	alias ParameterTypeTuple!T Params;

	// Keep the Lua function reachable through a registry reference, so the
	// delegate can fetch it again later.
	auto pinned = new class
	{
		int registryRef;

		this()
		{
			lua_pushvalue(L, idx);
			registryRef = luaL_ref(L, LUA_REGISTRYINDEX);
		}

		// There is deliberately no destructor calling
		// luaL_unref(L, LUA_REGISTRYINDEX, registryRef): this object tends to be
		// finalized after L has been freed (by LuaState's destructor or
		// otherwise), so releasing the reference there would touch a dangling
		// lua_State. A good solution to dangling lua_State references is welcome.

		void fetch()
		{
			lua_rawgeti(L, LUA_REGISTRYINDEX, registryRef);
		}
	};

	return delegate Result(Params params)
	{
		pinned.fetch();
		foreach(i, P; Params)
			pushValue(L, params[i]);

		return callWithRet!Result(L, params.length);
	};
}
291 
/**
 * Wrapper type that lets a function efficiently return a variable number
 * of values.
 *
 * Create instances with $(D variableReturn).
 * Params:
 *   Range = any input range
 */
struct LuaVariableReturn(Range) if(isInputRange!Range)
{
	/// The type of the wrapped input range.
	alias Range WrappedType;

	/// The wrapped input range.
	Range returnValues;
}
305 
/**
 * Wraps an input range in a LuaVariableReturn, so that a D function pushed to
 * Lua returns each element of the range as a separate Lua return value.
 * Params:
 *   returnValues = any input range
 * Example:
-----------------------------
	LuaVariableReturn!(uint[]) makeList(uint n)
	{
		uint[] list;

		foreach(i; 1 .. n + 1)
		{
			list ~= i;
		}

		return variableReturn(list);
	}

	lua["makeList"] = &makeList;

	lua.doString(`
		local one, two, three, four = makeList(4)
		assert(one == 1)
		assert(two == 2)
		assert(three == 3)
		assert(four == 4)
	`);
-----------------------------
 */
LuaVariableReturn!Range variableReturn(Range)(Range returnValues)
	if(isInputRange!Range)
{
	return LuaVariableReturn!Range(returnValues);
}
341 
// Declarations only compiled when building unittests: luad.testing (presumably
// the source of unittest_lua), std.typecons for Tuple, and the Lua state L shared
// by the unittest blocks of this module.
version(unittest)
{
	import luad.testing;
	import std.typecons;
	private lua_State* L;
}
348 
// Pushing D functions and delegates to Lua and calling them from Lua code.
// This unittest also creates the shared state `L` that the following unittests
// of this module reuse; the last unittest of the module closes it.
unittest
{
	L = luaL_newstate();
	luaL_openlibs(L);

	//functions
	// `static` nested function: no context, so &func is a plain function pointer.
	static const(char)[] func(const(char)[] s)
	{
		return "Hello, " ~ s;
	}

	pushValue(L, &func);
	assert(lua_isfunction(L, -1));
	lua_setglobal(L, "sayHello");

	unittest_lua(L, `
		local ret = sayHello("foo")
		local expect = "Hello, foo"
		assert(ret == expect,
			("sayHello return type - got '%s', expected '%s'"):format(ret, expect)
		)
	`);

	// Iterates the string as dchar, i.e. decodes UTF-8 code points.
	static uint countSpaces(const(char)[] s)
	{
		uint n = 0;
		foreach(dchar c; s)
			if(c == ' ')
				++n;

		return n;
	}

	pushValue(L, &countSpaces);
	assert(lua_isfunction(L, -1));
	lua_setglobal(L, "countSpaces");

	unittest_lua(L, `
		assert(countSpaces("Hello there, world!") == 2)
	`);

	//delegates
	// closure reads the local `curry`, so &closure is a delegate with a context.
	double curry = 3.14 * 2;
	double closure(double x)
	{
		return curry * x;
	}

	pushValue(L, &closure);
	assert(lua_isfunction(L, -1));
	lua_setglobal(L, "circle");

	unittest_lua(L, `
		assert(circle(2) == 3.14 * 4, "closure return type mismatch!")
	`);

	// Const parameters
	// Head-const (const(char[])) and `in` parameters must still be fillable.
	static bool isEmpty(const(char[]) str) { return str.length == 0; }
	static bool isEmpty2(in char[] str) { return str.length == 0; }

	pushValue(L, &isEmpty);
	lua_setglobal(L, "isEmpty");

	pushValue(L, &isEmpty2);
	lua_setglobal(L, "isEmpty2");

	unittest_lua(L, `
		assert(isEmpty(""))
		assert(isEmpty2(""))
		assert(not isEmpty("a"))
		assert(not isEmpty2("a"))
	`);

	// Immutable parameters
	// Head-immutable parameter and return type.
	static immutable(char[]) returnArg(immutable(char[]) str) { return str; }

	pushValue(L, &returnArg);
	lua_setglobal(L, "returnArg");

	unittest_lua(L, `assert(returnArg("foo") == "foo")`);
}
430 
431 version(unittest) import luad.base;
432 
// multiple return values
// A Tuple, a static array, or a LuaVariableReturn returned from D shows up in
// Lua as several return values, one per element.
unittest
{
	// tuple returns
	auto nameInfo = ["foo"];
	auto ageInfo = [42];

	alias Tuple!(string, "name", uint, "age") GetInfoResult;
	// Not static: it reads nameInfo/ageInfo, so &getInfo is a delegate.
	GetInfoResult getInfo(int idx)
	{
		GetInfoResult result;
		result.name = nameInfo[idx];
		result.age = ageInfo[idx];
		return result;
	}

	pushValue(L, &getInfo);
	lua_setglobal(L, "getInfo");

	unittest_lua(L, `
		local name, age = getInfo(0)
		assert(name == "foo")
		assert(age == 42)
	`);

	// static array returns
	static string[2] getName()
	{
		string[2] ret;
		ret[0] = "Foo";
		ret[1] = "Bar";
		return ret;
	}

	pushValue(L, &getName);
	lua_setglobal(L, "getName");

	unittest_lua(L, `
		local first, last = getName()
		assert(first == "Foo")
		assert(last == "Bar")
	`);

	// variable length returns
	// Wrapping an array.
	LuaVariableReturn!(uint[]) makeList(uint n)
	{
		uint[] list;

		foreach(i; 1 .. n + 1)
		{
			list ~= i;
		}

		return variableReturn(list);
	}

	// Wrapping a lazy range (iota) instead of an array.
	auto makeList2(uint n)
	{
		return variableReturn(iota(1, n + 1));
	}

	pushValue(L, &makeList);
	lua_setglobal(L, "makeList");
	pushValue(L, &makeList2);
	lua_setglobal(L, "makeList2");

	unittest_lua(L, `
		for i, f in pairs{makeList, makeList2} do
			local one, two, three, four = f(4)
			assert(one == 1)
			assert(two == 2)
			assert(three == 3)
			assert(four == 4)
		end
	`);
}
509 
// D-style typesafe varargs
// Lua arguments beyond the fixed parameters are collected into the trailing
// variadic array parameter.
unittest
{
	static string concat(const(char)[][] pieces...)
	{
		string result;
		foreach(piece; pieces)
			result ~= piece;
		return result;
	}

	pushValue(L, &concat);
	lua_setglobal(L, "concat");

	unittest_lua(L, `
		local whole = concat("he", "llo", ", ", "world!")
		assert(whole == "hello, world!")
	`);

	// A fixed `char` parameter (fed a one-character Lua string) before the varargs.
	static const(char)[] concat2(char separator, const(char)[][] pieces...)
	{
		if(pieces.length == 0)
			return "";

		string result;
		foreach(piece; pieces[0..$-1])
			result ~= piece ~ separator;

		return result ~ pieces[$-1];
	}

	pushValue(L, &concat2);
	lua_setglobal(L, "concat2");

	unittest_lua(L, `
		local whole = concat2(",", "one", "two", "three", "four")
		assert(whole == "one,two,three,four")
	`);
}
549 
// get delegates from Lua
// Converts Lua functions into D delegates and calls them from D.
unittest
{
	lua_getglobal(L, "string");
	lua_getfield(L, -1, "match");
	auto match = popValue!(string delegate(string, string))(L);
	lua_pop(L, 1); // pop the `string` table

	auto result = match("foobar@example.com", "([^@]+)@example.com");
	assert(result == "foobar");

	// multiple return values
	// Check that the chunk actually ran: on failure luaL_dostring leaves an
	// error message on the stack instead of defining multRet. The call is kept
	// outside assert() so it still runs if asserts are compiled out.
	auto status = luaL_dostring(L, `function multRet(a) return "foo", a end`);
	assert(status == 0, "failed to define multRet");
	lua_getglobal(L, "multRet");
	auto multRet = popValue!(Tuple!(string, int) delegate(int))(L);

	auto results = multRet(42);
	assert(results[0] == "foo");
	assert(results[1] == 42);
}
570 
// Nested call stack testing
// A D function called from Lua calls back into a Lua function (held as a D
// delegate); the Lua stack must be left balanced at every level.
unittest
{
	alias string delegate(string) MyFun;

	MyFun[string] funcs;

	pushValue(L, (string name, MyFun fun) {
		funcs[name] = fun;
	});
	lua_setglobal(L, "addFun");

	pushValue(L, (string name, string arg) {
		auto top = lua_gettop(L);
		auto result = funcs[name](arg);
		assert(lua_gettop(L) == top);
		return result;
	});
	lua_setglobal(L, "callFun");

	auto top = lua_gettop(L);

	// Report a failing chunk directly instead of only via the stack-height check
	// below. The call is kept outside assert() so it always runs.
	auto status = luaL_dostring(L, q{
		addFun("echo", function(s) return s end)
		local result = callFun("echo", "test")
		assert(result == "test")
	});
	assert(status == 0, "nested call chunk raised a Lua error");

	assert(lua_gettop(L) == top);
}
601 
// Last unittest of the module (unittests run in lexical order): verifies the
// earlier tests left the shared Lua stack empty, then closes the shared state.
unittest
{
	assert(lua_gettop(L) == 0);
	lua_close(L);
}