Skip to content

Commit 64b1d15

Browse files
committed
Deprecate Lua::load_from_finction and replace it with Lua::register_module and Lua::preload_module instead.
1 parent a8a4aa8 commit 64b1d15

File tree

6 files changed

+146
-67
lines changed

6 files changed

+146
-67
lines changed

mlua-sys/src/lua51/compat.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ pub unsafe fn luaL_getsubtable(L: *mut lua_State, idx: c_int, fname: *const c_ch
548548

549549
pub unsafe fn luaL_requiref(L: *mut lua_State, modname: *const c_char, openf: lua_CFunction, glb: c_int) {
550550
luaL_checkstack(L, 3, cstr!("not enough stack slots available"));
551-
luaL_getsubtable(L, LUA_REGISTRYINDEX, cstr!("_LOADED"));
551+
luaL_getsubtable(L, LUA_REGISTRYINDEX, LUA_LOADED_TABLE);
552552
if lua_getfield(L, -1, modname) == LUA_TNIL {
553553
lua_pop(L, 1);
554554
lua_pushcfunction(L, openf);

mlua-sys/src/lua51/lauxlib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ use super::lua::{self, lua_CFunction, lua_Integer, lua_Number, lua_State};
88
// Extra error code for 'luaL_load'
99
pub const LUA_ERRFILE: c_int = lua::LUA_ERRERR + 1;
1010

11+
// Key, in the registry, for table of loaded modules
12+
pub const LUA_LOADED_TABLE: *const c_char = cstr!("_LOADED");
13+
1114
#[repr(C)]
1215
pub struct luaL_Reg {
1316
pub name: *const c_char,

mlua-sys/src/lua52/compat.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ pub unsafe fn luaL_tolstring(L: *mut lua_State, mut idx: c_int, len: *mut usize)
232232

233233
pub unsafe fn luaL_requiref(L: *mut lua_State, modname: *const c_char, openf: lua_CFunction, glb: c_int) {
234234
luaL_checkstack(L, 3, cstr!("not enough stack slots available"));
235-
luaL_getsubtable(L, LUA_REGISTRYINDEX, cstr!("_LOADED"));
235+
luaL_getsubtable(L, LUA_REGISTRYINDEX, LUA_LOADED_TABLE);
236236
if lua_getfield(L, -1, modname) == LUA_TNIL {
237237
lua_pop(L, 1);
238238
lua_pushcfunction(L, openf);

mlua-sys/src/lua52/lauxlib.rs

+6
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ use super::lua::{self, lua_CFunction, lua_Integer, lua_Number, lua_State, lua_Un
88
// Extra error code for 'luaL_load'
99
pub const LUA_ERRFILE: c_int = lua::LUA_ERRERR + 1;
1010

11+
// Key, in the registry, for table of loaded modules
12+
pub const LUA_LOADED_TABLE: *const c_char = cstr!("_LOADED");
13+
14+
// Key, in the registry, for table of preloaded loaders
15+
pub const LUA_PRELOAD_TABLE: *const c_char = cstr!("_PRELOAD");
16+
1117
#[repr(C)]
1218
pub struct luaL_Reg {
1319
pub name: *const c_char,

src/state.rs

+70-42
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::any::TypeId;
22
use std::cell::{BorrowError, BorrowMutError, RefCell};
33
use std::marker::PhantomData;
44
use std::ops::Deref;
5-
use std::os::raw::c_int;
5+
use std::os::raw::{c_char, c_int};
66
use std::panic::Location;
77
use std::result::Result as StdResult;
88
use std::{fmt, mem, ptr};
@@ -347,40 +347,78 @@ impl Lua {
347347
unsafe { self.lock().load_std_libs(libs) }
348348
}
349349

350-
/// Loads module `modname` into an existing Lua state using the specified entrypoint
351-
/// function.
350+
/// Registers module into an existing Lua state using the specified value.
352351
///
353-
/// Internally calls the Lua function `func` with the string `modname` as an argument,
354-
/// sets the call result to `package.loaded[modname]` and returns copy of the result.
352+
/// After registration, the given value will always be immediately returned when the
353+
/// given module is [required].
355354
///
356-
/// If `package.loaded[modname]` value is not nil, returns copy of the value without
357-
/// calling the function.
355+
/// [required]: https://www.lua.org/manual/5.4/manual.html#pdf-require
356+
pub fn register_module(&self, modname: &str, value: impl IntoLua) -> Result<()> {
357+
#[cfg(not(feature = "luau"))]
358+
const LOADED_MODULES_KEY: *const c_char = ffi::LUA_LOADED_TABLE;
359+
#[cfg(feature = "luau")]
360+
const LOADED_MODULES_KEY: *const c_char = cstr!("_REGISTEREDMODULES");
361+
362+
if cfg!(feature = "luau") && !modname.starts_with('@') {
363+
return Err(Error::runtime("module name must begin with '@'"));
364+
}
365+
unsafe {
366+
self.exec_raw::<()>(value, |state| {
367+
ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, LOADED_MODULES_KEY);
368+
ffi::lua_pushlstring(state, modname.as_ptr() as *const c_char, modname.len() as _);
369+
ffi::lua_pushvalue(state, -3);
370+
ffi::lua_rawset(state, -3);
371+
})
372+
}
373+
}
374+
375+
/// Preloads module into an existing Lua state using the specified loader function.
358376
///
359-
/// If the function does not return a non-nil value then this method assigns true to
360-
/// `package.loaded[modname]`.
377+
/// When the module is required, the loader function will be called with module name as the
378+
/// first argument.
361379
///
362-
/// Behavior is similar to Lua's [`require`] function.
380+
/// This is similar to setting the [`package.preload[modname]`] field.
363381
///
364-
/// [`require`]: https://www.lua.org/manual/5.4/manual.html#pdf-require
365-
pub fn load_from_function<T>(&self, modname: &str, func: Function) -> Result<T>
366-
where
367-
T: FromLua,
368-
{
369-
let lua = self.lock();
370-
let state = lua.state();
382+
/// [`package.preload[modname]`]: https://www.lua.org/manual/5.4/manual.html#pdf-package.preload
383+
#[cfg(not(feature = "luau"))]
384+
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
385+
pub fn preload_module(&self, modname: &str, func: Function) -> Result<()> {
386+
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
387+
let preload = unsafe {
388+
self.exec_raw::<Option<Table>>((), |state| {
389+
ffi::lua_getfield(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_PRELOAD_TABLE);
390+
})?
391+
};
392+
#[cfg(any(feature = "lua51", feature = "luajit"))]
393+
let preload = unsafe {
394+
self.exec_raw::<Option<Table>>((), |state| {
395+
if ffi::lua_getfield(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_LOADED_TABLE) != ffi::LUA_TNIL {
396+
ffi::luaL_getsubtable(state, -1, ffi::LUA_LOADLIBNAME);
397+
ffi::luaL_getsubtable(state, -1, cstr!("preload"));
398+
ffi::lua_rotate(state, 1, 1);
399+
}
400+
})?
401+
};
402+
if let Some(preload) = preload {
403+
preload.raw_set(modname, func)?;
404+
}
405+
Ok(())
406+
}
407+
408+
#[doc(hidden)]
409+
#[deprecated(since = "0.11.0", note = "Use `register_module` instead")]
410+
#[cfg(not(feature = "luau"))]
411+
#[cfg(not(tarpaulin_include))]
412+
pub fn load_from_function<T: FromLua>(&self, modname: &str, func: Function) -> Result<T> {
371413
let loaded = unsafe {
372-
let _sg = StackGuard::new(state);
373-
check_stack(state, 2)?;
374-
protect_lua!(state, 0, 1, fn(state) {
375-
ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED"));
376-
})?;
377-
Table(lua.pop_ref())
414+
self.exec_raw::<Table>((), |state| {
415+
ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_LOADED_TABLE);
416+
})?
378417
};
379418

380-
let modname = unsafe { lua.create_string(modname)? };
381-
let value = match loaded.raw_get(&modname)? {
419+
let value = match loaded.raw_get(modname)? {
382420
Value::Nil => {
383-
let result = match func.call(&modname)? {
421+
let result = match func.call(modname)? {
384422
Value::Nil => Value::Boolean(true),
385423
res => res,
386424
};
@@ -394,24 +432,14 @@ impl Lua {
394432

395433
/// Unloads module `modname`.
396434
///
397-
/// Removes module from the [`package.loaded`] table which allows to load it again.
398-
/// It does not support unloading binary Lua modules since they are internally cached and can be
399-
/// unloaded only by closing Lua state.
435+
/// This method does not support unloading binary Lua modules since they are internally cached
436+
/// and can be unloaded only by closing Lua state.
437+
///
438+
/// This is similar to calling [`Lua::register_module`] with `Nil` value.
400439
///
401440
/// [`package.loaded`]: https://www.lua.org/manual/5.4/manual.html#pdf-package.loaded
402-
pub fn unload(&self, modname: &str) -> Result<()> {
403-
let lua = self.lock();
404-
let state = lua.state();
405-
let loaded = unsafe {
406-
let _sg = StackGuard::new(state);
407-
check_stack(state, 2)?;
408-
protect_lua!(state, 0, 1, fn(state) {
409-
ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED"));
410-
})?;
411-
Table(lua.pop_ref())
412-
};
413-
414-
loaded.raw_set(modname, Nil)
441+
pub fn unload_module(&self, modname: &str) -> Result<()> {
442+
self.register_module(modname, Nil)
415443
}
416444

417445
// Executes module entrypoint function, which returns only one Value.

tests/tests.rs

+65-23
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use std::collections::HashMap;
33
use std::iter::FromIterator;
44
use std::panic::{catch_unwind, AssertUnwindSafe};
55
use std::string::String as StdString;
6-
use std::sync::atomic::{AtomicU32, Ordering};
76
use std::sync::Arc;
87
use std::{error, f32, f64, fmt};
98

@@ -1168,36 +1167,79 @@ fn test_jit_version() -> Result<()> {
11681167
}
11691168

11701169
#[test]
1171-
fn test_load_from_function() -> Result<()> {
1170+
fn test_register_module() -> Result<()> {
11721171
let lua = Lua::new();
11731172

1174-
let i = Arc::new(AtomicU32::new(0));
1175-
let i2 = i.clone();
1176-
let func = lua.create_function(move |lua, modname: String| {
1177-
i2.fetch_add(1, Ordering::Relaxed);
1173+
let t = lua.create_table()?;
1174+
t.set("name", "my_module")?;
1175+
lua.register_module("@my_module", &t)?;
1176+
1177+
lua.load(
1178+
r#"
1179+
local my_module = require("@my_module")
1180+
assert(my_module.name == "my_module")
1181+
"#,
1182+
)
1183+
.exec()?;
1184+
1185+
lua.unload_module("@my_module")?;
1186+
lua.load(
1187+
r#"
1188+
local ok, err = pcall(function() return require("@my_module") end)
1189+
assert(not ok)
1190+
"#,
1191+
)
1192+
.exec()?;
1193+
1194+
#[cfg(feature = "luau")]
1195+
{
1196+
// Luau registered modules must have '@' prefix
1197+
let res = lua.register_module("my_module", 123);
1198+
assert!(res.is_err());
1199+
assert_eq!(
1200+
res.unwrap_err().to_string(),
1201+
"runtime error: module name must begin with '@'"
1202+
);
1203+
}
1204+
1205+
Ok(())
1206+
}
1207+
1208+
#[test]
1209+
#[cfg(not(feature = "luau"))]
1210+
fn test_preload_module() -> Result<()> {
1211+
let lua = Lua::new();
1212+
1213+
let loader = lua.create_function(move |lua, modname: String| {
11781214
let t = lua.create_table()?;
1179-
t.set("__name", modname)?;
1215+
t.set("name", modname)?;
11801216
Ok(t)
11811217
})?;
11821218

1183-
let t: Table = lua.load_from_function("my_module", func.clone())?;
1184-
assert_eq!(t.get::<String>("__name")?, "my_module");
1185-
assert_eq!(i.load(Ordering::Relaxed), 1);
1186-
1187-
let _: Value = lua.load_from_function("my_module", func.clone())?;
1188-
assert_eq!(i.load(Ordering::Relaxed), 1);
1189-
1190-
let func_nil = lua.create_function(move |_, _: String| Ok(Value::Nil))?;
1191-
let v: Value = lua.load_from_function("my_module2", func_nil)?;
1192-
assert_eq!(v, Value::Boolean(true));
1219+
lua.preload_module("@my_module", loader.clone())?;
1220+
lua.load(
1221+
r#"
1222+
-- `my_module` is global for purposes of next test
1223+
my_module = require("@my_module")
1224+
assert(my_module.name == "@my_module")
1225+
local my_module2 = require("@my_module")
1226+
assert(my_module == my_module2)
1227+
"#,
1228+
)
1229+
.exec()
1230+
.unwrap();
11931231

11941232
// Test unloading and loading again
1195-
lua.unload("my_module")?;
1196-
let _: Value = lua.load_from_function("my_module", func)?;
1197-
assert_eq!(i.load(Ordering::Relaxed), 2);
1198-
1199-
// Unloading nonexistent module must not fail
1200-
lua.unload("my_module2")?;
1233+
lua.unload_module("@my_module")?;
1234+
lua.load(
1235+
r#"
1236+
local my_module3 = require("@my_module")
1237+
-- `my_module` is not equal to `my_module3` because it was reloaded
1238+
assert(my_module ~= my_module3)
1239+
"#,
1240+
)
1241+
.exec()
1242+
.unwrap();
12011243

12021244
Ok(())
12031245
}

0 commit comments

Comments
 (0)