This commit is contained in:
tmddn3070 2023-12-02 15:03:50 +09:00
parent 3605311570
commit 600e6fdc62
15 changed files with 9479 additions and 40 deletions

View File

@ -1,10 +1,11 @@
from discord import ApplicationContext, DiscordException, SlashCommandGroup, Option, Member, File, Attachment, Color, utils, User, Message, TextChannel, Forbidden, HTTPException
from discord.ext.commands import Cog, BucketType, cooldown
import io
from discord import SlashCommandGroup, Option, Member, Color, File, Attachment
from discord.ext.commands import Cog, BucketType, cooldown, guild_only, Context
from Christmas.Database import database
from Christmas.UI.Embed import Aiart_Embed, ChristmasEmbed
from Christmas.UI.Embed import ChristmasEmbed, Aiart_Embed
from Christmas.UI.Modal import Aiart
from Christmas.Tagging import Tagging
from Christmas.Cogs.Event import model
class CAiart(Cog):
def __init__(self, bot):
self.bot = bot
@ -24,8 +25,8 @@ class CAiart(Cog):
nsfw = False
afterprocess = 1 - afterprocess
if ctx.channel.is_nsfw(): nsfw = True
if shows == "보여주기": shows = True
else: shows = False
if shows == "보여주기": shows = False
else: shows = True
if ress == "1:1": resoultion = [512,512]
elif ress == "2:3": resoultion = [512,768]
elif ress == "7:4": resoultion = [896,512]
@ -34,8 +35,29 @@ class CAiart(Cog):
modal = Aiart(title="태그를 입력해주세요",allownsfw=nsfw, res=resoultion, style1=style1, style2=style2, afterprocess=afterprocess, show=shows)
await ctx.send_modal(modal)
@ART.command(name="분석", description="그림을 분석합니다.")
@cooldown(1, 10, BucketType.user)
@guild_only()
async def _분석(self, ctx: Context, file: Option(Attachment, name="파일", description="분석할 그림을 업로드해주세요.", required=True)):
await ctx.defer(ephemeral=True)
if not await database.get_guild(ctx.guild.id): return await ctx.respond(embed=ChristmasEmbed(title="❌ 에러!", description="서버가 가입되어있지 않아요! 서버를 가입해주세요!", color=Color.red()),ephemeral=True)
if not file.content_type.startswith("image/"): return await ctx.respond(embed=ChristmasEmbed(title="❌ 에러!", description="그림 파일만 업로드해주세요!", color=Color.red()),ephemeral=True)
buffer = await file.read()
taging = Tagging(model=model)
tag = await taging.predict(buffer)
rating = tag[2]
tags = tag[4]
hangul = {
"general": "건전함",
"sensitive": "매우조금 불건전",
"questionable": "조금 불건전",
"explicit": "매우 불건전"
}
ratings = max(rating, key=rating.get)
rating = hangul[ratings]
sorted_tags = sorted(tags.items(), key=lambda x: x[1], reverse=True)[:8]
# UnicodeDecodeError: 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte
await ctx.respond(embed=Aiart_Embed.evalate(sorted_tags, rating), ephemeral=True, file=File(fp=io.BytesIO(buffer), filename="image.png"))
def setup(bot):
bot.add_cog(CAiart(bot))

View File

@ -17,7 +17,8 @@ class CMail(Cog):
@cooldown(1, 10, BucketType.user)
@guild_only()
async def _send(self, ctx: Context, member: Option(Member, name="보낼사람", description="편지를 받을 사람을 선택해주세요.", required=True)):
if pendulum.now().year >= 2023 and pendulum.now().month >= 12 and pendulum.now().day >= 24 and pendulum.now().hour >= 11 and pendulum.now().minute >= 59 and pendulum.now().second >= 59:
christmas = pendulum.datetime(2023, 12, 25, 0, 0, 0)
if pendulum.now() > christmas:
await ctx.respond("이미 편지를 보낼수 있는 기간이 지났어요! 받은 편지가 있다면 확인해보세요!", ephemeral=True)
return
if await database.get_mail_user(member.id, ctx.author.id):
@ -28,16 +29,18 @@ class CMail(Cog):
modal = Send_Mail_Modal(reciveuser=member, editmode=True, title="편지 수정하기")
await ctx.send_modal(modal)
return
else:
modal = Send_Mail_Modal(editmode=False, reciveuser=member, title="편지 보내기")
await ctx.send_modal(modal)
@MAIL.command(name="확인", description="받은 편지를 확인합니다.")
@cooldown(1, 10, BucketType.user)
async def _check(self, ctx: Context):
#관리자를 제외하고는 2023년 12월 25일 00시 00분 00초 이후부터 확인 가능
#if pendulum.now().year >= 2023 and pendulum.now().month >= 12 and pendulum.now().day >= 25 and pendulum.now().hour >= 00 and pendulum.now().minute >= 00 and pendulum.now().second >= 00:
# await ctx.respond("아직 편지를 확인할수 있는 기간이 아니에요! 조금만 기다려주세요!")
# return
#2023년 12월 25일 00시 00분 00초 이후부터 확인 가능
christmas = pendulum.datetime(2023, 12, 25, 0, 0, 0)
if pendulum.now() < christmas:
await ctx.respond("아직 편지를 확인할수 있는 기간이 아니에요! 조금만 기다려주세요!", ephemeral=True)
return
mails = await database.get_mail(ctx.author.id)
if mails == None:
await ctx.respond(embed=Mail_Embed.mail_notfound(), ephemeral=True)
@ -47,7 +50,6 @@ class CMail(Cog):
embeds = []
for data in mails:
embeds.append(Mail_Embed.mail_page(data))
print(embeds)
paginator = Mail_Paginator(embeds=embeds, senduser=ctx.author, timeout=None)
await ctx.respond(embed=embeds[0], view=paginator, ephemeral=True)
def setup(bot):

View File

@ -1,13 +1,54 @@
from discord import Member, SlashCommandGroup, Option
from discord.utils import basic_autocomplete
from discord import Member, SlashCommandGroup, Option, AutocompleteContext
from discord.ext.commands import Cog, cooldown, BucketType, command, has_permissions, bot_has_permissions, Context, guild_only, bot_has_guild_permissions, check
import wavelink
from Christmas.Database import database
from Christmas.UI.Embed import Music_Embed
async def search_music(ctx: AutocompleteContext):
try:
query_result = await wavelink.Playable.search(str(ctx.value), source=wavelink.TrackSource.YouTube)
data = []
for query in query_result:
if query is None: continue
if len(query.title) > 50: title = query.title[:50]+"..."
else: title = query.title
data.append(title)
return data
except Exception as e:
print(e)
return ["검색에 실패했어요!"]
class CMusic(Cog):
def __init__(self, bot):
self.bot = bot
MUSIC = SlashCommandGroup(name="음악", description="음악을 재생합니다.")
#
@MUSIC.command(name="재생", description="음악을 재생합니다.")
@cooldown(1, 10, BucketType.user)
@guild_only()
async def _play(self, ctx: Context, search: Option(str, name="검색", description="검색할 음악을 입력해주세요.", required=True, autocomplete=basic_autocomplete(search_music))):
if not database.get_guild(ctx.guild.id)["music"] == True: return await ctx.respond(embed=Music_Embed.music_notenable(), ephemeral=True)
if ctx.author.voice == None or ctx.author.voice.channel == None: return await ctx.respond(embed=Music_Embed.author_not_voice(), ephemeral=True)
player = None
if ctx.guild.me.voice == None or ctx.guild.me.voice.channel == None:
player = await ctx.author.voice.channel.connect(cls=wavelink.Player)
elif ctx.author.voice.channel != ctx.guild.me.voice.channel:
await ctx.guild.me.voice.channel.move(ctx.author.voice.channel)
player = await ctx.guild.me.voice.channel.connect(cls=wavelink.Player)
else:
player = wavelink.Node.get_player(ctx.guild.id)
#try:
# query_result = await wavelink.Playable.search(str(search), source=wavelink.TrackSource.YouTube)
# if query_result == None: return await ctx.respond(embed=Music_Embed.music_notfound(), ephemeral=True)
# if player.is_playing:
# await wavelink.Queue.put_wait(query_result[0])
# await ctx.respond(embed=Music_Embed.music_queue(query_result[0].title), ephemeral=True)
# else:
#
# await player.play(query_result[0])
# await ctx.respond(embed=Music_Embed.music_play(query_result[0].title), ephemeral=True)
def setup(bot):
bot.add_cog(CMusic(bot))

View File

@ -3,6 +3,7 @@ from discord.ext.commands import Cog, cooldown, BucketType, has_permissions, gui
from Christmas.UI.Embed import Default_Embed
from Christmas.Database import database
from Christmas.Module import get_gpuserver_status, Get_Backend_latency
class CUtil(Cog):
def __init__(self, bot):
@ -14,12 +15,18 @@ class CUtil(Cog):
@slash_command(name="서버가입", description="서버에 가입합니다.")
async def _join(self, ctx: Context):
try:
if await database.get_guild(ctx.guild.id): return await ctx.respond(embed=Default_Embed.already_register(), ephemeral=True)
await database.register_guild(ctx.guild.id)
await ctx.respond(embed=Default_Embed.register_sucess())
except Exception as e:
await ctx.respond(embed=Default_Embed.register_failed())
@guild_only()
@cooldown(1, 10, BucketType.user)
@slash_command(name="봇정보", description="봇의 정보를 확인합니다.")
async def _info(self, ctx: Context):
data = await get_gpuserver_status(url=None)
data2 = await Get_Backend_latency()
await ctx.respond(embed=Default_Embed.BotInfo(data, bot=self.bot, APIlatency=data2))
def setup(bot):
bot.add_cog(CUtil(bot))

View File

@ -1,24 +1,43 @@
import random
import wavelink
import onnxruntime as rt
from discord import ApplicationContext, DiscordException, Game, Guild
from discord.ext.commands import Cog
from discord.ext.commands import CommandOnCooldown
from Christmas.UI.Embed import Default_Embed
from Christmas.config import ChristmasConfig
model = rt.InferenceSession("Christmas/Tagging/model.onnx", provider_options="CPUExecutionProvider")
class Event(Cog):
def __init__(self, bot):
self.bot = bot
self.config = ChristmasConfig()
@Cog.listener()
async def on_application_command_error(self, ctx: ApplicationContext, exception: DiscordException) -> None:
print(exception)
if isinstance(exception, CommandOnCooldown):
await ctx.respond(embed=Default_Embed.cooldown(exception.retry_after), ephemeral=True)
@Cog.listener()
async def on_ready(self) -> None:
print("Ready!")
await self.bot.change_presence(activity=Game(name="크리스마스에 함께!"))
global model
model = rt.InferenceSession("Christmas/Tagging/model.onnx", provider_options="CPUExecutionProvider")
print("Model Loaded!")
# connect wavelink
@Cog.listener()
async def on_connect(self) -> None:
await self.bot.wait_until_ready()
nodes = []
for node in self.config.LAVALINK:
nodes.append(wavelink.Node(uri=node["HOST"], password=node["PASSWORD"], identifier=node["IDENTIFIER"]))
await wavelink.Pool.connect(nodes=nodes, client=self.bot)
@Cog.listener()
async def on_guild_join(self, guild: Guild) -> None:
@ -27,5 +46,9 @@ class Event(Cog):
else:
await random.choice(guild.text_channels).send(embed=Default_Embed.guild_join())
#@Cog.listener()
#async def on_wavelink_node_ready(self, node: wavelink.Node):
# print(f"Node {node.id} is ready!")
def setup(bot):
bot.add_cog(Event(bot))

View File

@ -46,8 +46,11 @@ class database:
async def get_instered_mail_edited(send_user_id: int, user_id: int):
try:
conn = await MongoDBClient().connect()
data = await conn.mail.find_one({"_id": user_id, "mails.userid": send_user_id})
return data["mails"][0]["edited"]
data = await conn.mail.find_one({"_id": user_id})
if data == None: return None
for mail in data["mails"]:
if mail["userid"] == send_user_id:
return mail["edited"]
except Exception as e:
print(e)
return None

View File

@ -1,4 +1,5 @@
import io
import time
import PIL
import base64
import re
@ -11,7 +12,7 @@ from discord import File
translator = Translator()
weight = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.8, 0.9]
cfg = [6,7,8,9,10,11]
cfg = [6, 7, 8, 9, 10, 11]
#def check_curse(text: str):
# return korcen_checker.check(text)
def is_korean(string):
@ -20,6 +21,10 @@ def is_korean(string):
return bool(match)
async def process_prompt(prompt: str, remove: str, res: list, isnsfw: bool, style1: float, style2: float, afterprocess: float):
# prompt에 "남자"가 들어있거나 remove에 "여자"가 들어있으면 man = True
man = False
if "남자" in prompt or "여자" in remove:
man = True
if is_korean(prompt):
prompt = await translator.translate(prompt, dest="en")
prompt = prompt.text
@ -33,6 +38,9 @@ async def process_prompt(prompt: str, remove: str, res: list, isnsfw: bool, styl
prompt = prompt + f"<lora:光影:{style2}>"
if remove != None:
negative_prompt = default_negative + "," + remove
if man == True:
prompt = prompt + "," + "(1boy)"
negative_prompt = negative_prompt + "," + "(1girl)"
add_prompt = random.choice([True, False])
if add_prompt == True:
qprompt = prompt + f"<lora:canistermix1.1:{random.choice(weight)}>"
@ -84,3 +92,37 @@ async def image_to_base64(image) -> str:
async def base64_to_image(base642) -> File:
attachment = File(io.BytesIO(base64.b64decode(base642)), filename="image.png")
return attachment
async def get_gpuserver_status(url) -> Dict:
async with aiohttp.ClientSession() as session:
try:
# latency도 측정
async with session.get("http://172.30.1.49:7860/sdapi/v1/memory", timeout=10) as response:
if response.status == 200:
# latency 측정
#latency = response.headers["X-Response-Time"]
result = await response.json()
memstatus = result["ram"]["used"]
cudamemstatus = result["cuda"]["system"]["used"]
oomcount = result["cuda"]["events"]["oom"]
return {"status": "online", "system_memory_usage": bytes_to_gb(memstatus), "cuda_memory_usage": bytes_to_gb(cudamemstatus), "oom_count": oomcount}
except Exception as e:
return {"status": "offline"}
def bytes_to_gb(bytes: int) -> float:
return round(bytes / 1024 / 1024 / 1024, 2)
async def Get_Backend_latency():
start_time = time.time()
async with aiohttp.ClientSession() as client:
try:
async with client.get("http://172.30.1.49:7860/sdapi/v1/memory", timeout=10) as response:
if response.status_code == 200 or response.status_code == 404:
return round(time.time() - start_time, 2)
else:
return None
except Exception as e:
return None

View File

@ -0,0 +1,110 @@
from Christmas.Cogs.Event import model
import typing
import asyncio
import io
import PIL
from PIL import Image
import numpy as np
import onnxruntime as rt
import pandas as pd
import cv2
class Tagging:
def __init__(self, model, ProviderOptions: typing.Optional[str] = "CPUExecutionProvider"):
self.model_path = "Christmas/Tagging/model.onnx"
self.tag_path = "Christmas/Tagging/tags.csv"
self.model = model
self.general_threshold: typing.Optional[float] = 0.35
self.character_threshold: typing.Optional[float] = 0.85
def make_square(self, img, target_size):
old_size = img.shape[:2]
desired_size = max(old_size)
desired_size = max(desired_size, target_size)
delta_w = desired_size - old_size[1]
delta_h = desired_size - old_size[0]
top, bottom = delta_h // 2, delta_h - (delta_h // 2)
left, right = delta_w // 2, delta_w - (delta_w // 2)
color = [255, 255, 255]
new_im = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
)
return new_im
def smart_resize(self, img, size):
# Assumes the image has already gone through make_square
if img.shape[0] > size:
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
elif img.shape[0] < size:
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
return img
def load_labels(self, path):
df = pd.read_csv(path)
tag_names = df["name"].tolist()
category = df["category"].tolist()
rating_indexes = [i for i, cat in enumerate(category) if cat == 9]
general_indexes = [i for i, cat in enumerate(category) if cat == 0]
character_indexes = [i for i, cat in enumerate(category) if cat == 4]
return tag_names, rating_indexes, general_indexes, character_indexes
def preprocess_image(self, image):
image = image.convert("RGBA")
new_image = Image.new("RGBA", image.size, "WHITE")
new_image.paste(image, mask=image)
image = new_image.convert("RGB")
image = np.asarray(image)[:, :, ::-1]
return image
def _predict_image(self, image, general_threshold, character_threshold):
"""discord.File의 byteio를 받아서 이미지를 예측합니다."""
image = Image.open(io.BytesIO(image))
tag_names, rating_indexes, general_indexes, character_indexes = self.load_labels(self.tag_path)
model = self.model
_, height, width, _ = model.get_inputs()[0].shape
image = self.preprocess_image(image)
image = self.make_square(image, height)
image = self.smart_resize(image, height)
image = image.astype(np.float32)
image = np.expand_dims(image, 0)
input_name = model.get_inputs()[0].name
label_name = model.get_outputs()[0].name
probs = model.run([label_name], {input_name: image})[0]
labels = list(zip(tag_names, probs[0].astype(float)))
# Extract ratings, general, and character labels
ratings_names = [labels[i] for i in rating_indexes]
general_names = [labels[i] for i in general_indexes]
character_names = [labels[i] for i in character_indexes]
rating = dict(ratings_names)
general_res = {name: prob for name, prob in general_names if prob > general_threshold}
character_res = {name: prob for name, prob in character_names if prob > character_threshold}
general_res_sorted = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True))
a = ", ".join(general_res_sorted.keys()).replace("_", " ").replace("(", "\(").replace(")", "\)")
c = ", ".join(general_res_sorted.keys())
return a, c, rating, character_res, general_res_sorted
async def predict(self, image, general_threshold: typing.Optional[float] = None, character_threshold: typing.Optional[float] = None):
if general_threshold is None:
general_threshold = self.general_threshold
if character_threshold is None:
character_threshold = self.character_threshold
return await asyncio.to_thread(self._predict_image, image, general_threshold, character_threshold)

Binary file not shown.

9084
Christmas/Tagging/tags.csv Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,15 @@
import wavelink
import typing
import psutil
from typing import Any, Optional
from discord import Embed, Colour, Embed, Member
from discord.types.embed import EmbedType
from datetime import datetime
from korcen import korcen
from Christmas.Module import get_gpuserver_status
start_time = datetime.now()
class ChristmasEmbed(Embed):
def __init__(self, *,
color: int | Colour | None = 0xf4f9ff,
@ -51,6 +57,45 @@ class Default_Embed:
embed.set_footer()
return embed
@staticmethod
def already_register():
embed = ChristmasEmbed(title="❌ 가입 실패!", description="이미 가입된 서버에요!")
embed.add_field(name="안내", value="이미 가입된 서버에요!", inline=False)
embed.set_footer()
return embed
@staticmethod
def cooldown(sec):
return "이 명령어는 " + str(sec) + "초 뒤에 다시 사용할 수 있어요!"
@staticmethod
def BotInfo(gpuserver: typing.Dict[str, typing.Any], bot, APIlatency) -> Embed:
current_time = datetime.now()
uptime = current_time - start_time
uptime = str(uptime).split(".")[0]
"""
gpuserver: [system_memory_usage,cuda_memory_usage,oom_count]
"""
embed = Embed(title="**봇 정보**", description="크돌이의 정보에요!")
embed.add_field(name="**봇 개요**", value=f"봇 ID: {bot.user.id}\n봇 버전: 0.0.9\n가동시간: {str(uptime)}", inline=False)
orin = psutil.virtual_memory().used
orin = orin / 1024 / 1024 / 1024
if gpuserver == None or gpuserver["status"] == "offline":
embed.add_field(name="**봇 자원**", value=f"GPU서버1 메모리 사용량: **오류**\nGPU서버1 CUDA 메모리 사용량: **오류**\nGPU서버1 메모리 오류 횟수: **오류**\n 현재 샤드 메모리 사용량:{round(orin)}GB", inline=False)
else:
mem_usage = gpuserver["system_memory_usage"]
cuda_memory_usage = gpuserver["cuda_memory_usage"]
oom_count = gpuserver["oom_count"]
embed.add_field(name="**봇 자원**", value=f"현재 샤드 메모리 사용량: {round(orin)}GB\n\nGPU서버1 메모리 사용량: {mem_usage}GB/128GB\nGPU서버1 GPU 메모리 사용량: {cuda_memory_usage}GB/80GB\nGPU서버1 메모리 오류 횟수: {oom_count}", inline=False)
embed.add_field(name="**봇 통계**", value=f"🏘️ **{len(bot.guilds)}**개의 서버에서 봇을 사용중이에요!\n🤖 **{len(bot.users)}**명의 유저와 함께하는 중이에요!", inline=False)
if APIlatency is None:
embed.add_field(name="**봇 핑**", value=f"🏓 **디스코드 게이트웨이 핑**: {round(bot.latency * 1000)}ms\n🏓 **AI 게이트웨이 핑**: 오류", inline=False)
else:
embed.add_field(name="**봇 핑**", value=f"🏓 **디스코드 게이트웨이 핑**: {round(bot.latency * 1000)}ms\n🏓 **AI 게이트웨이 핑**: {APIlatency}ms", inline=False)
embed.set_footer()
return embed
class Mail_Embed:
@ -77,7 +122,7 @@ class Mail_Embed:
@staticmethod
def mail_cant_edit():
embed = ChristmasEmbed(title="❌ 편지 수정 실패!", description="편지 수정에 실패했어요!")
embed.add_field(name="안내", value="편지는 한 번 전송하면 한번의 수정 기회 이후에는 취소할 수 없어요!", inline=False)
embed.add_field(name="안내", value="편지는 한 번 전송하면 한번의 수정 기회 이후에는 취소하거나 수정 할 수 없어요!", inline=False)
embed.set_footer()
return embed
@ -144,9 +189,57 @@ class Aiart_Embed:
embed.set_image(url="attachment://image.png")
embed.set_footer()
return embed
@staticmethod
def generating() -> Embed:
embed = ChristmasEmbed(title="그림 생성중...", description="그림을 생성하는 중이에요!")
embed.add_field(name="안내", value="그림 생성에는 최대 2분이 소요될 수 있어요!(크돌이는 돈이 없어요...)", inline=False)
embed.set_footer()
return embed
@staticmethod
def evalate(tags, rating) -> ChristmasEmbed:
if tags == None:
embed = ChristmasEmbed(name="**오류**", value="그림 분석에 실패했어요. 나중에 다시 시도해주세요", inline=False)
return embed
else:
embed = ChristmasEmbed(title="**그림 분석 완료**", description="그림 분석이 완료되었어요!")
embed.set_image(url="attachment://image.png")
embed.add_field(name="**등급**", value=f"{rating}", inline=False)
texts = ""
for tag, score in tags:
percentage = score * 100
texts = texts + f"{tag} : {percentage:.1f}%\n"
embed.add_field(name="**분석된 요소**", value=f"{texts}", inline=False)
embed.set_footer()
return embed
class Music_Embed:
@staticmethod
def music_notenable():
embed = ChristmasEmbed(title="❌ 음악 재생 실패!", description="음악 재생에 실패했어요!")
embed.add_field(name="안내", value="이 서버에서는 음악을 재생할수 없어요! \n만약 서버의 관리자라면 ``/설정`` 명령어로 음악 기능을 다시 활성화사킬수 있어요!", inline=False)
embed.set_footer()
return embed
@staticmethod
def author_not_voice():
embed = ChristmasEmbed(title="❌ 음악 재생 실패!", description="음악 재생에 실패했어요!")
embed.add_field(name="안내", value="음악을 재생하려면 음성채널에 들어가야 해요!", inline=False)
embed.set_footer()
return embed
@staticmethod
def author_diffrent_voice():
embed = ChristmasEmbed(title="❌ 음악 재생 실패!", description="음악 재생에 실패했어요!")
embed.add_field(name="안내", value="이미 다른 채널에 접속되어있어요!", inline=False)
embed.set_footer()
return embed
#@staticmethod
#def music_queue(music: wavelink.Playable):
#음악을 queue에 넣음
#embed = ChristmasEmbed(title="✅ 음악 재생 성공!", description="음악 재생에 성공했어요!")
#embed =

View File

@ -6,6 +6,7 @@ from discord.ui import Modal, InputText
from Christmas.UI.Embed import Mail_Embed, Aiart_Embed
from Christmas.UI.Buttons import Mail_Confirm_Button
from Christmas.Module import process_prompt, post_gpu_server, base64_to_image
from Christmas.config import ChristmasConfig
BLOCKTAG = [
"nsfw",
@ -40,6 +41,7 @@ class Send_Mail_Modal(Modal):
def __init__(self, reciveuser: Member, editmode: bool, *args, **kwargs):
self.reciveuser = reciveuser
self.editmode = editmode
super().__init__(timeout=None, *args, **kwargs)
self.add_item(InputText(label="제목", placeholder="제목을 입력해주세요.", style=InputTextStyle.short, required=True, custom_id="mail_title"))
@ -58,6 +60,7 @@ class Aiart(Modal):
self.res = res
self.style1 = style1
self.style2 = style2
self.config = ChristmasConfig()
self.afterprocess = afterprocess
super().__init__(timeout=None, *args, **kwargs)
@ -81,7 +84,7 @@ class Aiart(Modal):
await interaction.response.send_message(embed=Aiart_Embed.generating(), ephemeral=self.show)
#prompt: str, remove: str, res: list, isnsfw: bool, style1: float, style2: float, afterprocess: float
payload = await process_prompt(prompt, remove, self.res, self.allownsfw, self.style1, self.style2, self.afterprocess)
result = await post_gpu_server("http://172.30.1.49:7860/sdapi/v1/txt2img", payload)
result = await post_gpu_server(f"{ChristmasConfig.AI()}/sdapi/v1/txt2img", payload)
if result["status"] != True:
return await interaction.edit_original_response(embed=Aiart_Embed.failed_generate())
else:

View File

@ -32,3 +32,11 @@ class ChristmasConfig:
@property
def DATABASE(self):
return self.json["DATABASE"]
@property
def LAVALINK(self):
return self.json["LAVALINKS"]
@property
def AI(self):
return self.json["AI_GATEWAY"]

View File

@ -1,4 +1,5 @@
import os
import mafic
from types import SimpleNamespace
from typing import Any, cast
@ -21,8 +22,6 @@ class Christmas(AutoShardedBot):
def load_cogs(bot) -> None:
for filename in os.listdir("Christmas/Cogs"):
if filename == "__pycache__":
continue
if filename.endswith(".py"):
bot.load_extension(f"Christmas.Cogs.{filename[:-3]}")

View File

@ -5,3 +5,5 @@ korcen
nanoid
pendulum
Wavelink
onnxruntime
psutil