diff --git a/README.md b/README.md index 54f86fe..bacec9f 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,22 @@ -# gigchat +# GigaChat API Lua SDK -GigaChat API Lua SDK \ No newline at end of file +Документация API: https://developers.sber.ru/docs/ru/gigachat/api/reference/rest/ + +## Использование + +Необходимо переименовать gigachat.test_conf в gigachat.conf и заполнить в нем креды + +``` + +local api = require('gigachat.api') + +-- Получение списка моделей + +api.models() + +-- Запрос к чату (строка или массив сообщений) + +api.chat.completions('тест') + + +``` diff --git a/config/config.lua b/config/config.lua new file mode 100644 index 0000000..78e7c19 --- /dev/null +++ b/config/config.lua @@ -0,0 +1,52 @@ +local json = require('cjson.safe') +local _M = {} +_M.data = {} +_M.comments = {} +_M.file = '' -- файл конфигурации +local key + +function _M.read() + for line in io.lines(_M.file) do + key = string.match(line, '([%w_]+)::') + if (key) then + _M.data[key] = string.match(line, '::(.*) #') + _M.comments[key] = string.match(line, '#(.*)') + if string.find(_M.data[key],'%{%"') or string.find(_M.data[key],'%[%"') then + _M.data[key] = json.decode(_M.data[key]) + end + end + end +end + +function reprint(k,v) + if type(v) == 'table' then + for i,j in pairs(v) do + reprint(i,j) + end + else + print(k..': '..v) + end +end + +function _M.data:write() + local config_file = io.open(_M.file, 'w') + for k,v in pairs(self) do + if type(v) ~= 'function' then + if type(v) == 'table' then v = json.encode(v) end + config_file:write(k..'::'..v..' #'.._M.comments[key]..'\n') + end + end + config_file:close() +end + +function _M.data:print() + for k,v in pairs(self) do + if type(v) ~= 'function' then + if _M.comments[k] then print('\n'.._M.comments[k]:gsub("^%s*(.-)%s*$", "%1")..': \n') end + reprint(k,v) + end + end +end + +return _M + diff --git a/config/gigachat.lua b/config/gigachat.lua new file mode 100644 index 0000000..675210a --- /dev/null +++ b/config/gigachat.lua @@ -0,0 +1,5 @@ +local config = require('config.config') +local _M = config +_M.file = 'gigachat.conf' -- файл конфигурации +_M.read() +return _M.data diff --git a/gigachat.test_conf b/gigachat.test_conf new file mode 100644 index 0000000..3558849 --- /dev/null +++ b/gigachat.test_conf @@ -0,0 +1,8 @@ +auth_key:: #Ключ авторизации +client_id:: #ClientID +scope::GIGACHAT_API_PERS #Scope +token_url::https://ngw.devices.sberbank.ru:9443/api/v2/oauth #Token endpoint +base_url::https://gigachat.devices.sberbank.ru/api/v1/ #Базовый url для информационных ответов +logs_path::/var/www/gigachat/logs/ #Путь к логам +db_path::/var/www/gigachat #Путь к бд хранения токенов + diff --git a/gigachat/api.lua b/gigachat/api.lua new file mode 100644 index 0000000..528393f --- /dev/null +++ b/gigachat/api.lua @@ -0,0 +1,172 @@ +local json = require('cjson') +local cURL = require('cURL') +local log = require('utils.log') +local config = require('config.gigachat') +local uuid = require('utils.uuid') +local flatdb = require('utils.flatdb') +local db = flatdb(config.db_path) + +if not db.token then + db.token = {} + db:save() +end + +local _M = {} + +_M.chat = {} + +log.outfile = config.logs_path..'gigachat_'..os.date('%Y-%m-%d')..'.log' +log.level = 'trace' + +local function poster(data) + local result = {} + for i,k in pairs(data) do table.insert(result, i..'='..k) end + return table.concat(result,'&') +end + +local function get_result(str) + local result, err = pcall(json.decode,str) + if result then + result = json.decode(str) + else + log.error(err) + return nil, err + end + return result,err +end + +-- Получение актуального токена доступа + +function _M.token() + local token = table.remove(db.token, 1) + if token and tonumber(token.expires_at/1000) > os.time() then return token.access_token + else + local str = '' + local headers = { + 'Content-type: application/x-www-form-urlencoded', + 'Accept: application/json', + 'RqUID: '..uuid, + 'Authorization: Basic '..config.auth_key + } + local c = cURL.easy{ + url = config.token_url, + post = true, + postfields = poster({scope=config.scope}), + httpheader = headers, + writefunction = function(st) + str = str..st + collectgarbage("collect") + return #st + end + } + local ok, err = c:perform() + local code = c:getinfo_response_code() + c:close() + if not ok then return nil, err end + if code ~= 200 then + log.error(str) + return nil,str + end + local res,err = get_result(str) + if res then + db.token = {res} + db:save() + return res.access_token + end + end +end + +local function get(endpoint) + local str = '' + local url = config.base_url..endpoint + local headers = { + 'Content-type: application/x-www-form-urlencoded', + 'Accept: application/json', + 'Authorization: Bearer '.._M.token() + } + local c = cURL.easy{ + url = url, + httpheader = headers, + writefunction = function(st) + str = str..st + collectgarbage("collect") + return #st + end + } + local ok, err = c:perform() + local code = c:getinfo_response_code() + c:close() + if not ok then return nil, err end + if code ~= 200 then + log.error(str) + return nil,str + end + local res,err = get_result(str) + if res then return res end + return res,err +end + +local function post(endpoint,data) + local str = '' + local url = config.base_url..endpoint + local headers = { + 'Content-type: application/x-www-form-urlencoded', + 'Accept: application/json', + 'Authorization: Bearer '.._M.token() + } + local c = cURL.easy{ + url = url, + httpheader = headers, + post = true, + postfields = json.encode(data), + writefunction = function(st) + str = str..st + collectgarbage("collect") + return #st + end + } + local ok, err = c:perform() + local code = c:getinfo_response_code() + c:close() + if not ok then return nil, err end + if code ~= 200 then + log.error(str) + return nil,str + end + local res,err = get_result(str) + if res then return res end + return res,err +end + +-- Получение списка доступных моделей + +function _M.models() + local models,err = get('models') + if models and models.data then return models.data end + return models,err +end + +-- Генерация ответа модели + +function _M.chat.completions(messages,model,max_tokens,repetition_penalty,stream,update_interval,temperature,top_p,function_call,functions) + if type(messages) == 'string' then messages = {{role='user',content=messages}} end + if temperature and top_p then return nil, 'Одновременно указаны top_p и температура выборки' end + return post('chat/completions',{ + model = model or 'GigaChat-2', + stream = stream or false, + max_tokens = max_tokens or 512, + repetition_penalty = repetition_penalty or 1, + update_interval = update_interval or 0, + messages = messages + }) +end + +-- Создание эмбеддингов + +function _M.embeddings(model) + return post('embeddings',{ + model = model or 'embeddings' + }) +end + +return _M diff --git a/utils/flatdb.lua b/utils/flatdb.lua new file mode 100644 index 0000000..60916f5 --- /dev/null +++ b/utils/flatdb.lua @@ -0,0 +1,86 @@ +local mp = require("MessagePack") + +local function isFile(path) + local f = io.open(path, "r") + if f then + f:close() + return true + end + return false +end + +local function isDir(path) + path = string.gsub(path.."/", "//", "/") + local ok, err, code = os.rename(path, path) + if ok or code == 13 then + return true + end + return false +end + +local function load_page(path) + local ret + local f = io.open(path, "rb") + if f then + ret = mp.unpack(f:read("*a")) + f:close() + end + return ret +end + +local function store_page(path, page) + if type(page) == "table" then + local f = io.open(path, "wb") + if f then + f:write(mp.pack(page)) + f:close() + return true + end + end + return false +end + +local pool = {} + +local db_funcs = { + save = function(db, p) + if p then + if type(p) == "string" and type(db[p]) == "table" then + return store_page(pool[db].."/"..p, db[p]) + else + return false + end + end + for p, page in pairs(db) do + if not store_page(pool[db].."/"..p, page) then + return false + end + end + return true + end +} + +local mt = { + __index = function(db, k) + if db_funcs[k] then return db_funcs[k] end + if isFile(pool[db].."/"..k) then + db[k] = load_page(pool[db].."/"..k) + end + return rawget(db, k) + end +} + +pool.hack = db_funcs + +return setmetatable(pool, { + __mode = "kv", + __call = function(pool, path) + assert(isDir(path), path.." is not a directory.") + if pool[path] then return pool[path] end + local db = {} + setmetatable(db, mt) + pool[path] = db + pool[db] = path + return db + end +}) diff --git a/utils/log.lua b/utils/log.lua new file mode 100644 index 0000000..d7bc2d4 --- /dev/null +++ b/utils/log.lua @@ -0,0 +1,90 @@ +-- +-- log.lua +-- +-- Copyright (c) 2016 rxi +-- +-- This library is free software; you can redistribute it and/or modify it +-- under the terms of the MIT license. See LICENSE for details. +-- + +local log = { _version = "0.1.0" } + +log.usecolor = true +log.outfile = nil +log.level = "trace" + + +local modes = { + { name = "trace", color = "\27[34m", }, + { name = "debug", color = "\27[36m", }, + { name = "info", color = "\27[32m", }, + { name = "warn", color = "\27[33m", }, + { name = "error", color = "\27[31m", }, + { name = "fatal", color = "\27[35m", }, +} + + +local levels = {} +for i, v in ipairs(modes) do + levels[v.name] = i +end + + +local round = function(x, increment) + increment = increment or 1 + x = x / increment + return (x > 0 and math.floor(x + .5) or math.ceil(x - .5)) * increment +end + + +local _tostring = tostring + +local tostring = function(...) + local t = {} + for i = 1, select('#', ...) do + local x = select(i, ...) + if type(x) == "number" then + x = round(x, .01) + end + t[#t + 1] = _tostring(x) + end + return table.concat(t, " ") +end + + +for i, x in ipairs(modes) do + local nameupper = x.name:upper() + log[x.name] = function(...) + + -- Return early if we're below the log level + if i < levels[log.level] then + return + end + + local msg = tostring(...) + local info = debug.getinfo(2, "Sl") + local lineinfo = info.short_src .. ":" .. info.currentline + + -- Output to console + print(string.format("%s[%-6s%s]%s %s: %s", + log.usecolor and x.color or "", + nameupper, + os.date("%H:%M:%S"), + log.usecolor and "\27[0m" or "", + lineinfo, + msg)) + + -- Output to log file + if log.outfile then + local fp = io.open(log.outfile, "a") + local str = string.format("[%-6s%s] %s: %s\n", + nameupper, os.date(), lineinfo, msg) + fp:write(str) + fp:close() + end + + end +end + + +return log diff --git a/utils/uuid.lua b/utils/uuid.lua new file mode 100644 index 0000000..c8289e7 --- /dev/null +++ b/utils/uuid.lua @@ -0,0 +1,10 @@ +local random = math.random +local function uuid() + local template ='xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx' + return string.gsub(template, '[xy]', function (c) + local v = (c == 'x') and random(0, 0xf) or random(8, 0xb) + return string.format('%x', v) + end) +end + +return uuid()