diff --git a/go/calc.go b/go/calc.go index e0e67d4..cc96787 100644 --- a/go/calc.go +++ b/go/calc.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/chzyer/readline" + lua "github.com/yuin/gopher-lua" ) type Calc struct { @@ -17,6 +18,7 @@ type Calc struct { stack *Stack history []string completer readline.AutoCompleter + L *lua.LState } const Help string = `Available commands: @@ -34,10 +36,24 @@ basic operators: + - * / Math operators: ^ power` +// That way I can add custom functions to completion +func GetCompleteCustomFunctions() func(string) []string { + return func(line string) []string { + funcs := []string{} + for luafunc := range LuaFuncs { + funcs = append(funcs, luafunc) + } + return funcs + } +} + func NewCalc() *Calc { c := Calc{stack: NewStack(), debug: false} c.completer = readline.NewPrefixCompleter( + // custom lua functions + readline.PcItemDynamic(GetCompleteCustomFunctions()), + // commands readline.PcItem("dump"), readline.PcItem("reverse"), @@ -102,11 +118,16 @@ func (c *Calc) Eval(line string) { "SqrtPhi", "Ln2", "Log2E", "Ln10", "Log10E"} functions := []string{"sqrt", "remainder", "%", "%-", "%+"} batch := []string{"median", "avg"} + luafuncs := []string{} if line == "" { return } + for luafunc := range LuaFuncs { + luafuncs = append(luafuncs, luafunc) + } + for _, item := range space.Split(line, -1) { num, err := strconv.ParseFloat(item, 64) @@ -135,6 +156,11 @@ func (c *Calc) Eval(line string) { continue } + if contains(luafuncs, item) { + c.luafunc(item) + continue + } + switch item { case "help": fmt.Println(Help) @@ -333,38 +359,22 @@ func (c *Calc) batchfunc(funcname string) { _ = c.Result() } -func contains(s []string, e string) bool { - for _, a := range s { - if a == e { - return true - } - } - return false -} +func (c *Calc) luafunc(funcname string) { + // we may need to put them onto the stack afterwards! + c.stack.Backup() + b := c.stack.Pop() + a := c.stack.Pop() -func const2num(name string) float64 { - switch name { - case "Pi": - return math.Pi - case "Phi": - return math.Phi - case "Sqrt2": - return math.Sqrt2 - case "SqrtE": - return math.SqrtE - case "SqrtPi": - return math.SqrtPi - case "SqrtPhi": - return math.SqrtPhi - case "Ln2": - return math.Ln2 - case "Log2E": - return math.Log2E - case "Ln10": - return math.Ln10 - case "Log10E": - return math.Log10E - default: - return 0 + x, err := CallLuaFunc(c.L, funcname, a, b) + if err != nil { + fmt.Println(err) + c.stack.Push(a) + c.stack.Push(b) + return } + + c.History("%s(%f,%f) = %f", funcname, a, b, x) + c.stack.Push(x) + + c.Result() } diff --git a/go/go.mod b/go/go.mod index 19dd436..486d1d1 100644 --- a/go/go.mod +++ b/go/go.mod @@ -5,5 +5,6 @@ go 1.20 require ( github.com/chzyer/readline v1.5.1 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/yuin/gopher-lua v1.1.0 // indirect golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 // indirect ) diff --git a/go/go.sum b/go/go.sum index 242b615..1de60da 100644 --- a/go/go.sum +++ b/go/go.sum @@ -4,5 +4,7 @@ github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObk github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/yuin/gopher-lua v1.1.0 h1:BojcDhfyDWgU2f2TOzYK/g5p2gxMrku8oupLDqlnSqE= +github.com/yuin/gopher-lua v1.1.0/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 h1:y/woIyUBFbpQGKS0u1aHF/40WUDnek3fPOyD08H5Vng= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/go/interpreter.go b/go/interpreter.go new file mode 100644 index 0000000..1ae4e2b --- /dev/null +++ b/go/interpreter.go @@ -0,0 +1,72 @@ +package main + +import ( + "errors" + "fmt" + + lua "github.com/yuin/gopher-lua" +) + +// LUA interpreter, instanciated in main() +var L *lua.LState + +var LuaFuncs map[string]int + +// called from lua to register a 1 arg math function +func RegisterFuncOneArg(L *lua.LState) int { + function := L.ToString(1) + LuaFuncs[function] = 1 + return 1 +} + +// called from lua to register a 1 arg math function +func RegisterFuncTwoArg(L *lua.LState) int { + function := L.ToString(1) + LuaFuncs[function] = 2 + return 1 +} + +func InitLua(L *lua.LState) { + LuaFuncs = map[string]int{} + L.SetGlobal("RegisterFuncOneArg", L.NewFunction(RegisterFuncOneArg)) + L.SetGlobal("RegisterFuncTwoArg", L.NewFunction(RegisterFuncTwoArg)) + + if err := L.CallByParam(lua.P{ + Fn: L.GetGlobal("init"), + NRet: 0, + Protect: true, + }); err != nil { + panic(err) + } +} + +func CallLuaFunc(L *lua.LState, funcname string, a float64, b float64) (float64, error) { + if LuaFuncs[funcname] == 1 { + // 1 arg variant + if err := L.CallByParam(lua.P{ + Fn: L.GetGlobal(funcname), + NRet: 1, + Protect: true, + }, lua.LNumber(a)); err != nil { + fmt.Println(err) + return 0, err + } + } else { + // 2 arg variant + if err := L.CallByParam(lua.P{ + Fn: L.GetGlobal(funcname), + NRet: 1, + Protect: true, + }, lua.LNumber(a), lua.LNumber(b)); err != nil { + return 0, err + } + } + + // get result and cast to float64 + if res, ok := L.Get(-1).(lua.LNumber); ok { + L.Pop(1) + return float64(res), nil + } + + return 0, errors.New("function did not return a float64!") +} diff --git a/go/main.go b/go/main.go index 40a35e0..8e92fde 100644 --- a/go/main.go +++ b/go/main.go @@ -6,6 +6,7 @@ import ( "github.com/chzyer/readline" flag "github.com/spf13/pflag" + lua "github.com/yuin/gopher-lua" ) const VERSION string = "0.0.1" @@ -23,8 +24,7 @@ Options: When is given, batch mode ist automatically enabled. Use this only when working with stdin. E.g.: echo "2 3 4 5" | rpn + -Copyright (c) 2023 T.v.Dein -` +Copyright (c) 2023 T.v.Dein` func main() { calc := NewCalc() @@ -32,11 +32,14 @@ func main() { showversion := false showhelp := false enabledebug := false + configfile := "" flag.BoolVarP(&calc.batch, "batchmode", "b", false, "batch mode") flag.BoolVarP(&enabledebug, "debug", "d", false, "debug mode") flag.BoolVarP(&showversion, "version", "v", false, "show version") flag.BoolVarP(&showhelp, "help", "h", false, "show usage") + flag.StringVarP(&configfile, "config", "c", os.Getenv("HOME")+"/.rpn.lua", "config file (lua format)") + flag.Parse() if showversion { @@ -53,6 +56,18 @@ func main() { calc.ToggleDebug() } + if _, err := os.Stat(configfile); err == nil { + L = lua.NewState() + defer L.Close() + + if err := L.DoFile(configfile); err != nil { + panic(err) + } + + InitLua(L) + calc.L = L + } + rl, err := readline.NewEx(&readline.Config{ Prompt: "\033[31m»\033[0m ", HistoryFile: os.Getenv("HOME") + "/.rpn-history", diff --git a/go/rpn b/go/rpn index 7ee1857..1e85f83 100755 Binary files a/go/rpn and b/go/rpn differ diff --git a/go/test.lua b/go/test.lua new file mode 100644 index 0000000..066f69f --- /dev/null +++ b/go/test.lua @@ -0,0 +1,12 @@ +function add(a,b) + return a + b +end + +function parallelresistance(a,b) + return 1.0 / (a * b) +end + +function init() + RegisterFuncTwoArg("add") + RegisterFuncTwoArg("parallelresistance") +end diff --git a/go/util.go b/go/util.go new file mode 100644 index 0000000..0962bb0 --- /dev/null +++ b/go/util.go @@ -0,0 +1,39 @@ +package main + +import "math" + +func contains(s []string, e string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +} + +func const2num(name string) float64 { + switch name { + case "Pi": + return math.Pi + case "Phi": + return math.Phi + case "Sqrt2": + return math.Sqrt2 + case "SqrtE": + return math.SqrtE + case "SqrtPi": + return math.SqrtPi + case "SqrtPhi": + return math.SqrtPhi + case "Ln2": + return math.Ln2 + case "Log2E": + return math.Log2E + case "Ln10": + return math.Ln10 + case "Log10E": + return math.Log10E + default: + return 0 + } +}