script.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. package rds
  2. import (
  3. "github.com/go-redis/redis"
  4. "io/ioutil"
  5. "one.com/kettle/utl"
  6. "os"
  7. "path/filepath"
  8. "strings"
  9. )
  10. /**
  11. 提供对 redis 中 lua 脚本的简单封装
  12. lua 主要用于逻辑层面的批量 key 操作。 这和 multi 命令不同:
  13. 1. lua 脚本是 atomic 的, 可以看作事务。(和DBMS事务不同,主要不支持回滚,这里仅仅是提供了“一次性”的逻辑)
  14. 2. lua 脚本执行不会中途打断。
  15. 3. lua 脚本相比 multi 有更小的数据传数量
  16. 实现目标
  17. 支持 lua 从 lua 文件加载
  18. 支持从字符串加载
  19. 支持从目录加载
  20. 加载后,需要保存 redis 保存脚本后返回的对应的 hash
  21. 調用方可根据文件名调用
  22. 从字符串加载的,可自己设置名字
  23. 名字不能重名
  24. */
  25. type luablocker interface {
  26. Eval(script string, keys []string, args ...interface{}) *redis.Cmd
  27. EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd
  28. ScriptExists(hashes ...string) *redis.BoolSliceCmd
  29. ScriptLoad(script string) *redis.StringCmd
  30. }
  31. var _ luablocker = (*redis.Client)(nil)
  32. var _ luablocker = (*redis.Ring)(nil)
  33. var _ luablocker = (*redis.ClusterClient)(nil)
  34. type LuaSrc struct {
  35. code, hash string
  36. }
  37. // LuaBlockManager
  38. type LuaScriptManager struct {
  39. m map[string]*LuaSrc
  40. h luablocker
  41. }
  42. func NewLuaScriptManager(handle luablocker) *LuaScriptManager {
  43. return &LuaScriptManager{
  44. m: make(map[string]*LuaSrc),
  45. h: handle,
  46. }
  47. }
  48. func (self *LuaScriptManager) Exec(name string, keys []string, args ...interface{}) *redis.Cmd {
  49. src, ok := self.m[name]
  50. if !ok {
  51. panic("not load lua script:" + name)
  52. }
  53. return self.h.EvalSha(src.hash, keys, args...)
  54. }
  55. func (self *LuaScriptManager) LoadString(name string, src string) error {
  56. hash, err := self.h.ScriptLoad(src).Result()
  57. if err != nil {
  58. return err
  59. }
  60. self.m[name] = &LuaSrc{
  61. code: src,
  62. hash: hash,
  63. }
  64. return nil
  65. }
  66. func (self *LuaScriptManager) LoadFile(f string) error {
  67. ext := filepath.Ext(f)
  68. if strings.ToLower(ext) != ".lua" {
  69. return utl.ErrParameters
  70. }
  71. data, err := ioutil.ReadFile(f)
  72. if err != nil {
  73. return err
  74. }
  75. filename := filepath.Base(f) // xxx.lua
  76. name := strings.Split(filename, ".")[0] //xxx
  77. return self.LoadString(name, string(data))
  78. }
  79. func (self *LuaScriptManager) LoadPath(p string) error {
  80. err := filepath.Walk(p, func(path string, info os.FileInfo, err error) error {
  81. if err != nil {
  82. return err
  83. }
  84. fi, err := os.Stat(path)
  85. if err != nil {
  86. return err
  87. }
  88. if fi.IsDir() {
  89. return nil
  90. }
  91. return self.LoadFile(path)
  92. })
  93. return err
  94. }