added lua interpreter support for custom math functions

This commit is contained in:
2023-10-31 19:02:40 +01:00
parent 4ace2b4385
commit e10faf2204
8 changed files with 185 additions and 34 deletions

View File

@@ -8,6 +8,7 @@ import (
"strings" "strings"
"github.com/chzyer/readline" "github.com/chzyer/readline"
lua "github.com/yuin/gopher-lua"
) )
type Calc struct { type Calc struct {
@@ -17,6 +18,7 @@ type Calc struct {
stack *Stack stack *Stack
history []string history []string
completer readline.AutoCompleter completer readline.AutoCompleter
L *lua.LState
} }
const Help string = `Available commands: const Help string = `Available commands:
@@ -34,10 +36,24 @@ basic operators: + - * /
Math operators: Math operators:
^ power` ^ power`
// That way I can add custom functions to completion
func GetCompleteCustomFunctions() func(string) []string {
return func(line string) []string {
funcs := []string{}
for luafunc := range LuaFuncs {
funcs = append(funcs, luafunc)
}
return funcs
}
}
func NewCalc() *Calc { func NewCalc() *Calc {
c := Calc{stack: NewStack(), debug: false} c := Calc{stack: NewStack(), debug: false}
c.completer = readline.NewPrefixCompleter( c.completer = readline.NewPrefixCompleter(
// custom lua functions
readline.PcItemDynamic(GetCompleteCustomFunctions()),
// commands // commands
readline.PcItem("dump"), readline.PcItem("dump"),
readline.PcItem("reverse"), readline.PcItem("reverse"),
@@ -102,11 +118,16 @@ func (c *Calc) Eval(line string) {
"SqrtPhi", "Ln2", "Log2E", "Ln10", "Log10E"} "SqrtPhi", "Ln2", "Log2E", "Ln10", "Log10E"}
functions := []string{"sqrt", "remainder", "%", "%-", "%+"} functions := []string{"sqrt", "remainder", "%", "%-", "%+"}
batch := []string{"median", "avg"} batch := []string{"median", "avg"}
luafuncs := []string{}
if line == "" { if line == "" {
return return
} }
for luafunc := range LuaFuncs {
luafuncs = append(luafuncs, luafunc)
}
for _, item := range space.Split(line, -1) { for _, item := range space.Split(line, -1) {
num, err := strconv.ParseFloat(item, 64) num, err := strconv.ParseFloat(item, 64)
@@ -135,6 +156,11 @@ func (c *Calc) Eval(line string) {
continue continue
} }
if contains(luafuncs, item) {
c.luafunc(item)
continue
}
switch item { switch item {
case "help": case "help":
fmt.Println(Help) fmt.Println(Help)
@@ -333,38 +359,22 @@ func (c *Calc) batchfunc(funcname string) {
_ = c.Result() _ = c.Result()
} }
func contains(s []string, e string) bool { func (c *Calc) luafunc(funcname string) {
for _, a := range s { // we may need to put them onto the stack afterwards!
if a == e { c.stack.Backup()
return true b := c.stack.Pop()
} a := c.stack.Pop()
}
return false
}
func const2num(name string) float64 { x, err := CallLuaFunc(c.L, funcname, a, b)
switch name { if err != nil {
case "Pi": fmt.Println(err)
return math.Pi c.stack.Push(a)
case "Phi": c.stack.Push(b)
return math.Phi return
case "Sqrt2":
return math.Sqrt2
case "SqrtE":
return math.SqrtE
case "SqrtPi":
return math.SqrtPi
case "SqrtPhi":
return math.SqrtPhi
case "Ln2":
return math.Ln2
case "Log2E":
return math.Log2E
case "Ln10":
return math.Ln10
case "Log10E":
return math.Log10E
default:
return 0
} }
c.History("%s(%f,%f) = %f", funcname, a, b, x)
c.stack.Push(x)
c.Result()
} }

View File

@@ -5,5 +5,6 @@ go 1.20
require ( require (
github.com/chzyer/readline v1.5.1 // indirect github.com/chzyer/readline v1.5.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/yuin/gopher-lua v1.1.0 // indirect
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 // indirect golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 // indirect
) )

View File

@@ -4,5 +4,7 @@ github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObk
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/yuin/gopher-lua v1.1.0 h1:BojcDhfyDWgU2f2TOzYK/g5p2gxMrku8oupLDqlnSqE=
github.com/yuin/gopher-lua v1.1.0/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 h1:y/woIyUBFbpQGKS0u1aHF/40WUDnek3fPOyD08H5Vng= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 h1:y/woIyUBFbpQGKS0u1aHF/40WUDnek3fPOyD08H5Vng=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

72
go/interpreter.go Normal file
View File

@@ -0,0 +1,72 @@
package main
import (
"errors"
"fmt"
lua "github.com/yuin/gopher-lua"
)
// LUA interpreter, instanciated in main()
var L *lua.LState
var LuaFuncs map[string]int
// called from lua to register a 1 arg math function
func RegisterFuncOneArg(L *lua.LState) int {
function := L.ToString(1)
LuaFuncs[function] = 1
return 1
}
// called from lua to register a 1 arg math function
func RegisterFuncTwoArg(L *lua.LState) int {
function := L.ToString(1)
LuaFuncs[function] = 2
return 1
}
func InitLua(L *lua.LState) {
LuaFuncs = map[string]int{}
L.SetGlobal("RegisterFuncOneArg", L.NewFunction(RegisterFuncOneArg))
L.SetGlobal("RegisterFuncTwoArg", L.NewFunction(RegisterFuncTwoArg))
if err := L.CallByParam(lua.P{
Fn: L.GetGlobal("init"),
NRet: 0,
Protect: true,
}); err != nil {
panic(err)
}
}
func CallLuaFunc(L *lua.LState, funcname string, a float64, b float64) (float64, error) {
if LuaFuncs[funcname] == 1 {
// 1 arg variant
if err := L.CallByParam(lua.P{
Fn: L.GetGlobal(funcname),
NRet: 1,
Protect: true,
}, lua.LNumber(a)); err != nil {
fmt.Println(err)
return 0, err
}
} else {
// 2 arg variant
if err := L.CallByParam(lua.P{
Fn: L.GetGlobal(funcname),
NRet: 1,
Protect: true,
}, lua.LNumber(a), lua.LNumber(b)); err != nil {
return 0, err
}
}
// get result and cast to float64
if res, ok := L.Get(-1).(lua.LNumber); ok {
L.Pop(1)
return float64(res), nil
}
return 0, errors.New("function did not return a float64!")
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/chzyer/readline" "github.com/chzyer/readline"
flag "github.com/spf13/pflag" flag "github.com/spf13/pflag"
lua "github.com/yuin/gopher-lua"
) )
const VERSION string = "0.0.1" const VERSION string = "0.0.1"
@@ -23,8 +24,7 @@ Options:
When <operator> is given, batch mode ist automatically enabled. Use When <operator> is given, batch mode ist automatically enabled. Use
this only when working with stdin. E.g.: echo "2 3 4 5" | rpn + this only when working with stdin. E.g.: echo "2 3 4 5" | rpn +
Copyright (c) 2023 T.v.Dein Copyright (c) 2023 T.v.Dein`
`
func main() { func main() {
calc := NewCalc() calc := NewCalc()
@@ -32,11 +32,14 @@ func main() {
showversion := false showversion := false
showhelp := false showhelp := false
enabledebug := false enabledebug := false
configfile := ""
flag.BoolVarP(&calc.batch, "batchmode", "b", false, "batch mode") flag.BoolVarP(&calc.batch, "batchmode", "b", false, "batch mode")
flag.BoolVarP(&enabledebug, "debug", "d", false, "debug mode") flag.BoolVarP(&enabledebug, "debug", "d", false, "debug mode")
flag.BoolVarP(&showversion, "version", "v", false, "show version") flag.BoolVarP(&showversion, "version", "v", false, "show version")
flag.BoolVarP(&showhelp, "help", "h", false, "show usage") flag.BoolVarP(&showhelp, "help", "h", false, "show usage")
flag.StringVarP(&configfile, "config", "c", os.Getenv("HOME")+"/.rpn.lua", "config file (lua format)")
flag.Parse() flag.Parse()
if showversion { if showversion {
@@ -53,6 +56,18 @@ func main() {
calc.ToggleDebug() calc.ToggleDebug()
} }
if _, err := os.Stat(configfile); err == nil {
L = lua.NewState()
defer L.Close()
if err := L.DoFile(configfile); err != nil {
panic(err)
}
InitLua(L)
calc.L = L
}
rl, err := readline.NewEx(&readline.Config{ rl, err := readline.NewEx(&readline.Config{
Prompt: "\033[31m»\033[0m ", Prompt: "\033[31m»\033[0m ",
HistoryFile: os.Getenv("HOME") + "/.rpn-history", HistoryFile: os.Getenv("HOME") + "/.rpn-history",

BIN
go/rpn

Binary file not shown.

12
go/test.lua Normal file
View File

@@ -0,0 +1,12 @@
function add(a,b)
return a + b
end
function parallelresistance(a,b)
return 1.0 / (a * b)
end
function init()
RegisterFuncTwoArg("add")
RegisterFuncTwoArg("parallelresistance")
end

39
go/util.go Normal file
View File

@@ -0,0 +1,39 @@
package main
import "math"
func contains(s []string, e string) bool {
for _, a := range s {
if a == e {
return true
}
}
return false
}
func const2num(name string) float64 {
switch name {
case "Pi":
return math.Pi
case "Phi":
return math.Phi
case "Sqrt2":
return math.Sqrt2
case "SqrtE":
return math.SqrtE
case "SqrtPi":
return math.SqrtPi
case "SqrtPhi":
return math.SqrtPhi
case "Ln2":
return math.Ln2
case "Log2E":
return math.Log2E
case "Ln10":
return math.Ln10
case "Log10E":
return math.Log10E
default:
return 0
}
}