0.0.9
This commit is contained in:
parent
3605311570
commit
600e6fdc62
@ -1,10 +1,11 @@
|
|||||||
from discord import ApplicationContext, DiscordException, SlashCommandGroup, Option, Member, File, Attachment, Color, utils, User, Message, TextChannel, Forbidden, HTTPException
|
import io
|
||||||
from discord.ext.commands import Cog, BucketType, cooldown
|
from discord import SlashCommandGroup, Option, Member, Color, File, Attachment
|
||||||
|
from discord.ext.commands import Cog, BucketType, cooldown, guild_only, Context
|
||||||
from Christmas.Database import database
|
from Christmas.Database import database
|
||||||
from Christmas.UI.Embed import Aiart_Embed, ChristmasEmbed
|
from Christmas.UI.Embed import ChristmasEmbed, Aiart_Embed
|
||||||
from Christmas.UI.Modal import Aiart
|
from Christmas.UI.Modal import Aiart
|
||||||
|
from Christmas.Tagging import Tagging
|
||||||
|
from Christmas.Cogs.Event import model
|
||||||
class CAiart(Cog):
|
class CAiart(Cog):
|
||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
@ -24,8 +25,8 @@ class CAiart(Cog):
|
|||||||
nsfw = False
|
nsfw = False
|
||||||
afterprocess = 1 - afterprocess
|
afterprocess = 1 - afterprocess
|
||||||
if ctx.channel.is_nsfw(): nsfw = True
|
if ctx.channel.is_nsfw(): nsfw = True
|
||||||
if shows == "보여주기": shows = True
|
if shows == "보여주기": shows = False
|
||||||
else: shows = False
|
else: shows = True
|
||||||
if ress == "1:1": resoultion = [512,512]
|
if ress == "1:1": resoultion = [512,512]
|
||||||
elif ress == "2:3": resoultion = [512,768]
|
elif ress == "2:3": resoultion = [512,768]
|
||||||
elif ress == "7:4": resoultion = [896,512]
|
elif ress == "7:4": resoultion = [896,512]
|
||||||
@ -34,8 +35,29 @@ class CAiart(Cog):
|
|||||||
modal = Aiart(title="태그를 입력해주세요",allownsfw=nsfw, res=resoultion, style1=style1, style2=style2, afterprocess=afterprocess, show=shows)
|
modal = Aiart(title="태그를 입력해주세요",allownsfw=nsfw, res=resoultion, style1=style1, style2=style2, afterprocess=afterprocess, show=shows)
|
||||||
await ctx.send_modal(modal)
|
await ctx.send_modal(modal)
|
||||||
|
|
||||||
|
@ART.command(name="분석", description="그림을 분석합니다.")
|
||||||
|
@cooldown(1, 10, BucketType.user)
|
||||||
|
@guild_only()
|
||||||
|
async def _분석(self, ctx: Context, file: Option(Attachment, name="파일", description="분석할 그림을 업로드해주세요.", required=True)):
|
||||||
|
await ctx.defer(ephemeral=True)
|
||||||
|
if not await database.get_guild(ctx.guild.id): return await ctx.respond(embed=ChristmasEmbed(title="❌ 에러!", description="서버가 가입되어있지 않아요! 서버를 가입해주세요!", color=Color.red()),ephemeral=True)
|
||||||
|
if not file.content_type.startswith("image/"): return await ctx.respond(embed=ChristmasEmbed(title="❌ 에러!", description="그림 파일만 업로드해주세요!", color=Color.red()),ephemeral=True)
|
||||||
|
buffer = await file.read()
|
||||||
|
taging = Tagging(model=model)
|
||||||
|
tag = await taging.predict(buffer)
|
||||||
|
rating = tag[2]
|
||||||
|
tags = tag[4]
|
||||||
|
hangul = {
|
||||||
|
"general": "건전함",
|
||||||
|
"sensitive": "매우조금 불건전",
|
||||||
|
"questionable": "조금 불건전",
|
||||||
|
"explicit": "매우 불건전"
|
||||||
|
}
|
||||||
|
ratings = max(rating, key=rating.get)
|
||||||
|
rating = hangul[ratings]
|
||||||
|
sorted_tags = sorted(tags.items(), key=lambda x: x[1], reverse=True)[:8]
|
||||||
|
# UnicodeDecodeError: 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte
|
||||||
|
await ctx.respond(embed=Aiart_Embed.evalate(sorted_tags, rating), ephemeral=True, file=File(fp=io.BytesIO(buffer), filename="image.png"))
|
||||||
|
|
||||||
def setup(bot):
|
def setup(bot):
|
||||||
bot.add_cog(CAiart(bot))
|
bot.add_cog(CAiart(bot))
|
@ -17,7 +17,8 @@ class CMail(Cog):
|
|||||||
@cooldown(1, 10, BucketType.user)
|
@cooldown(1, 10, BucketType.user)
|
||||||
@guild_only()
|
@guild_only()
|
||||||
async def _send(self, ctx: Context, member: Option(Member, name="보낼사람", description="편지를 받을 사람을 선택해주세요.", required=True)):
|
async def _send(self, ctx: Context, member: Option(Member, name="보낼사람", description="편지를 받을 사람을 선택해주세요.", required=True)):
|
||||||
if pendulum.now().year >= 2023 and pendulum.now().month >= 12 and pendulum.now().day >= 24 and pendulum.now().hour >= 11 and pendulum.now().minute >= 59 and pendulum.now().second >= 59:
|
christmas = pendulum.datetime(2023, 12, 25, 0, 0, 0)
|
||||||
|
if pendulum.now() > christmas:
|
||||||
await ctx.respond("이미 편지를 보낼수 있는 기간이 지났어요! 받은 편지가 있다면 확인해보세요!", ephemeral=True)
|
await ctx.respond("이미 편지를 보낼수 있는 기간이 지났어요! 받은 편지가 있다면 확인해보세요!", ephemeral=True)
|
||||||
return
|
return
|
||||||
if await database.get_mail_user(member.id, ctx.author.id):
|
if await database.get_mail_user(member.id, ctx.author.id):
|
||||||
@ -28,16 +29,18 @@ class CMail(Cog):
|
|||||||
modal = Send_Mail_Modal(reciveuser=member, editmode=True, title="편지 수정하기")
|
modal = Send_Mail_Modal(reciveuser=member, editmode=True, title="편지 수정하기")
|
||||||
await ctx.send_modal(modal)
|
await ctx.send_modal(modal)
|
||||||
return
|
return
|
||||||
modal = Send_Mail_Modal(editmode=False, reciveuser=member, title="편지 보내기")
|
else:
|
||||||
await ctx.send_modal(modal)
|
modal = Send_Mail_Modal(editmode=False, reciveuser=member, title="편지 보내기")
|
||||||
|
await ctx.send_modal(modal)
|
||||||
|
|
||||||
@MAIL.command(name="확인", description="받은 편지를 확인합니다.")
|
@MAIL.command(name="확인", description="받은 편지를 확인합니다.")
|
||||||
@cooldown(1, 10, BucketType.user)
|
@cooldown(1, 10, BucketType.user)
|
||||||
async def _check(self, ctx: Context):
|
async def _check(self, ctx: Context):
|
||||||
#관리자를 제외하고는 2023년 12월 25일 00시 00분 00초 이후부터 확인 가능
|
#2023년 12월 25일 00시 00분 00초 이후부터 확인 가능
|
||||||
#if pendulum.now().year >= 2023 and pendulum.now().month >= 12 and pendulum.now().day >= 25 and pendulum.now().hour >= 00 and pendulum.now().minute >= 00 and pendulum.now().second >= 00:
|
christmas = pendulum.datetime(2023, 12, 25, 0, 0, 0)
|
||||||
# await ctx.respond("아직 편지를 확인할수 있는 기간이 아니에요! 조금만 기다려주세요!")
|
if pendulum.now() < christmas:
|
||||||
# return
|
await ctx.respond("아직 편지를 확인할수 있는 기간이 아니에요! 조금만 기다려주세요!", ephemeral=True)
|
||||||
|
return
|
||||||
mails = await database.get_mail(ctx.author.id)
|
mails = await database.get_mail(ctx.author.id)
|
||||||
if mails == None:
|
if mails == None:
|
||||||
await ctx.respond(embed=Mail_Embed.mail_notfound(), ephemeral=True)
|
await ctx.respond(embed=Mail_Embed.mail_notfound(), ephemeral=True)
|
||||||
@ -47,7 +50,6 @@ class CMail(Cog):
|
|||||||
embeds = []
|
embeds = []
|
||||||
for data in mails:
|
for data in mails:
|
||||||
embeds.append(Mail_Embed.mail_page(data))
|
embeds.append(Mail_Embed.mail_page(data))
|
||||||
print(embeds)
|
|
||||||
paginator = Mail_Paginator(embeds=embeds, senduser=ctx.author, timeout=None)
|
paginator = Mail_Paginator(embeds=embeds, senduser=ctx.author, timeout=None)
|
||||||
await ctx.respond(embed=embeds[0], view=paginator, ephemeral=True)
|
await ctx.respond(embed=embeds[0], view=paginator, ephemeral=True)
|
||||||
def setup(bot):
|
def setup(bot):
|
||||||
|
@ -1,13 +1,54 @@
|
|||||||
from discord import Member, SlashCommandGroup, Option
|
from discord.utils import basic_autocomplete
|
||||||
|
from discord import Member, SlashCommandGroup, Option, AutocompleteContext
|
||||||
from discord.ext.commands import Cog, cooldown, BucketType, command, has_permissions, bot_has_permissions, Context, guild_only, bot_has_guild_permissions, check
|
from discord.ext.commands import Cog, cooldown, BucketType, command, has_permissions, bot_has_permissions, Context, guild_only, bot_has_guild_permissions, check
|
||||||
|
import wavelink
|
||||||
|
|
||||||
|
from Christmas.Database import database
|
||||||
|
from Christmas.UI.Embed import Music_Embed
|
||||||
|
async def search_music(ctx: AutocompleteContext):
|
||||||
|
try:
|
||||||
|
query_result = await wavelink.Playable.search(str(ctx.value), source=wavelink.TrackSource.YouTube)
|
||||||
|
data = []
|
||||||
|
for query in query_result:
|
||||||
|
if query is None: continue
|
||||||
|
if len(query.title) > 50: title = query.title[:50]+"..."
|
||||||
|
else: title = query.title
|
||||||
|
data.append(title)
|
||||||
|
return data
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return ["검색에 실패했어요!"]
|
||||||
|
|
||||||
class CMusic(Cog):
|
class CMusic(Cog):
|
||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
|
||||||
|
MUSIC = SlashCommandGroup(name="음악", description="음악을 재생합니다.")
|
||||||
|
#
|
||||||
|
@MUSIC.command(name="재생", description="음악을 재생합니다.")
|
||||||
|
@cooldown(1, 10, BucketType.user)
|
||||||
|
@guild_only()
|
||||||
|
async def _play(self, ctx: Context, search: Option(str, name="검색", description="검색할 음악을 입력해주세요.", required=True, autocomplete=basic_autocomplete(search_music))):
|
||||||
|
if not database.get_guild(ctx.guild.id)["music"] == True: return await ctx.respond(embed=Music_Embed.music_notenable(), ephemeral=True)
|
||||||
|
if ctx.author.voice == None or ctx.author.voice.channel == None: return await ctx.respond(embed=Music_Embed.author_not_voice(), ephemeral=True)
|
||||||
|
player = None
|
||||||
|
if ctx.guild.me.voice == None or ctx.guild.me.voice.channel == None:
|
||||||
|
player = await ctx.author.voice.channel.connect(cls=wavelink.Player)
|
||||||
|
elif ctx.author.voice.channel != ctx.guild.me.voice.channel:
|
||||||
|
await ctx.guild.me.voice.channel.move(ctx.author.voice.channel)
|
||||||
|
player = await ctx.guild.me.voice.channel.connect(cls=wavelink.Player)
|
||||||
|
else:
|
||||||
|
player = wavelink.Node.get_player(ctx.guild.id)
|
||||||
|
#try:
|
||||||
|
# query_result = await wavelink.Playable.search(str(search), source=wavelink.TrackSource.YouTube)
|
||||||
|
# if query_result == None: return await ctx.respond(embed=Music_Embed.music_notfound(), ephemeral=True)
|
||||||
|
# if player.is_playing:
|
||||||
|
# await wavelink.Queue.put_wait(query_result[0])
|
||||||
|
# await ctx.respond(embed=Music_Embed.music_queue(query_result[0].title), ephemeral=True)
|
||||||
|
# else:
|
||||||
|
#
|
||||||
|
# await player.play(query_result[0])
|
||||||
|
# await ctx.respond(embed=Music_Embed.music_play(query_result[0].title), ephemeral=True)
|
||||||
|
|
||||||
def setup(bot):
|
def setup(bot):
|
||||||
bot.add_cog(CMusic(bot))
|
bot.add_cog(CMusic(bot))
|
@ -3,6 +3,7 @@ from discord.ext.commands import Cog, cooldown, BucketType, has_permissions, gui
|
|||||||
|
|
||||||
from Christmas.UI.Embed import Default_Embed
|
from Christmas.UI.Embed import Default_Embed
|
||||||
from Christmas.Database import database
|
from Christmas.Database import database
|
||||||
|
from Christmas.Module import get_gpuserver_status, Get_Backend_latency
|
||||||
|
|
||||||
class CUtil(Cog):
|
class CUtil(Cog):
|
||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
@ -14,12 +15,18 @@ class CUtil(Cog):
|
|||||||
@slash_command(name="서버가입", description="서버에 가입합니다.")
|
@slash_command(name="서버가입", description="서버에 가입합니다.")
|
||||||
async def _join(self, ctx: Context):
|
async def _join(self, ctx: Context):
|
||||||
try:
|
try:
|
||||||
|
if await database.get_guild(ctx.guild.id): return await ctx.respond(embed=Default_Embed.already_register(), ephemeral=True)
|
||||||
await database.register_guild(ctx.guild.id)
|
await database.register_guild(ctx.guild.id)
|
||||||
await ctx.respond(embed=Default_Embed.register_sucess())
|
await ctx.respond(embed=Default_Embed.register_sucess())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await ctx.respond(embed=Default_Embed.register_failed())
|
await ctx.respond(embed=Default_Embed.register_failed())
|
||||||
|
|
||||||
|
@guild_only()
|
||||||
|
@cooldown(1, 10, BucketType.user)
|
||||||
|
@slash_command(name="봇정보", description="봇의 정보를 확인합니다.")
|
||||||
|
async def _info(self, ctx: Context):
|
||||||
|
data = await get_gpuserver_status(url=None)
|
||||||
|
data2 = await Get_Backend_latency()
|
||||||
|
await ctx.respond(embed=Default_Embed.BotInfo(data, bot=self.bot, APIlatency=data2))
|
||||||
def setup(bot):
|
def setup(bot):
|
||||||
bot.add_cog(CUtil(bot))
|
bot.add_cog(CUtil(bot))
|
@ -1,24 +1,43 @@
|
|||||||
import random
|
import random
|
||||||
|
import wavelink
|
||||||
|
|
||||||
|
import onnxruntime as rt
|
||||||
|
|
||||||
from discord import ApplicationContext, DiscordException, Game, Guild
|
from discord import ApplicationContext, DiscordException, Game, Guild
|
||||||
from discord.ext.commands import Cog
|
from discord.ext.commands import Cog
|
||||||
|
|
||||||
|
from discord.ext.commands import CommandOnCooldown
|
||||||
from Christmas.UI.Embed import Default_Embed
|
from Christmas.UI.Embed import Default_Embed
|
||||||
|
from Christmas.config import ChristmasConfig
|
||||||
|
|
||||||
|
model = rt.InferenceSession("Christmas/Tagging/model.onnx", provider_options="CPUExecutionProvider")
|
||||||
|
|
||||||
class Event(Cog):
|
class Event(Cog):
|
||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
self.config = ChristmasConfig()
|
||||||
|
|
||||||
@Cog.listener()
|
@Cog.listener()
|
||||||
async def on_application_command_error(self, ctx: ApplicationContext, exception: DiscordException) -> None:
|
async def on_application_command_error(self, ctx: ApplicationContext, exception: DiscordException) -> None:
|
||||||
print(exception)
|
if isinstance(exception, CommandOnCooldown):
|
||||||
|
await ctx.respond(embed=Default_Embed.cooldown(exception.retry_after), ephemeral=True)
|
||||||
|
|
||||||
@Cog.listener()
|
@Cog.listener()
|
||||||
async def on_ready(self) -> None:
|
async def on_ready(self) -> None:
|
||||||
print("Ready!")
|
print("Ready!")
|
||||||
await self.bot.change_presence(activity=Game(name="크리스마스에 함께!"))
|
await self.bot.change_presence(activity=Game(name="크리스마스에 함께!"))
|
||||||
|
global model
|
||||||
|
model = rt.InferenceSession("Christmas/Tagging/model.onnx", provider_options="CPUExecutionProvider")
|
||||||
|
print("Model Loaded!")
|
||||||
|
|
||||||
|
# connect wavelink
|
||||||
|
@Cog.listener()
|
||||||
|
async def on_connect(self) -> None:
|
||||||
|
await self.bot.wait_until_ready()
|
||||||
|
nodes = []
|
||||||
|
for node in self.config.LAVALINK:
|
||||||
|
nodes.append(wavelink.Node(uri=node["HOST"], password=node["PASSWORD"], identifier=node["IDENTIFIER"]))
|
||||||
|
await wavelink.Pool.connect(nodes=nodes, client=self.bot)
|
||||||
|
|
||||||
@Cog.listener()
|
@Cog.listener()
|
||||||
async def on_guild_join(self, guild: Guild) -> None:
|
async def on_guild_join(self, guild: Guild) -> None:
|
||||||
@ -27,5 +46,9 @@ class Event(Cog):
|
|||||||
else:
|
else:
|
||||||
await random.choice(guild.text_channels).send(embed=Default_Embed.guild_join())
|
await random.choice(guild.text_channels).send(embed=Default_Embed.guild_join())
|
||||||
|
|
||||||
|
#@Cog.listener()
|
||||||
|
#async def on_wavelink_node_ready(self, node: wavelink.Node):
|
||||||
|
# print(f"Node {node.id} is ready!")
|
||||||
|
|
||||||
def setup(bot):
|
def setup(bot):
|
||||||
bot.add_cog(Event(bot))
|
bot.add_cog(Event(bot))
|
@ -46,8 +46,11 @@ class database:
|
|||||||
async def get_instered_mail_edited(send_user_id: int, user_id: int):
|
async def get_instered_mail_edited(send_user_id: int, user_id: int):
|
||||||
try:
|
try:
|
||||||
conn = await MongoDBClient().connect()
|
conn = await MongoDBClient().connect()
|
||||||
data = await conn.mail.find_one({"_id": user_id, "mails.userid": send_user_id})
|
data = await conn.mail.find_one({"_id": user_id})
|
||||||
return data["mails"][0]["edited"]
|
if data == None: return None
|
||||||
|
for mail in data["mails"]:
|
||||||
|
if mail["userid"] == send_user_id:
|
||||||
|
return mail["edited"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
return None
|
return None
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import io
|
import io
|
||||||
|
import time
|
||||||
import PIL
|
import PIL
|
||||||
import base64
|
import base64
|
||||||
import re
|
import re
|
||||||
@ -11,7 +12,7 @@ from discord import File
|
|||||||
|
|
||||||
translator = Translator()
|
translator = Translator()
|
||||||
weight = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.8, 0.9]
|
weight = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.8, 0.9]
|
||||||
cfg = [6,7,8,9,10,11]
|
cfg = [6, 7, 8, 9, 10, 11]
|
||||||
#def check_curse(text: str):
|
#def check_curse(text: str):
|
||||||
# return korcen_checker.check(text)
|
# return korcen_checker.check(text)
|
||||||
def is_korean(string):
|
def is_korean(string):
|
||||||
@ -20,6 +21,10 @@ def is_korean(string):
|
|||||||
return bool(match)
|
return bool(match)
|
||||||
|
|
||||||
async def process_prompt(prompt: str, remove: str, res: list, isnsfw: bool, style1: float, style2: float, afterprocess: float):
|
async def process_prompt(prompt: str, remove: str, res: list, isnsfw: bool, style1: float, style2: float, afterprocess: float):
|
||||||
|
# prompt에 "남자"가 들어있거나 remove에 "여자"가 들어있으면 man = True
|
||||||
|
man = False
|
||||||
|
if "남자" in prompt or "여자" in remove:
|
||||||
|
man = True
|
||||||
if is_korean(prompt):
|
if is_korean(prompt):
|
||||||
prompt = await translator.translate(prompt, dest="en")
|
prompt = await translator.translate(prompt, dest="en")
|
||||||
prompt = prompt.text
|
prompt = prompt.text
|
||||||
@ -33,6 +38,9 @@ async def process_prompt(prompt: str, remove: str, res: list, isnsfw: bool, styl
|
|||||||
prompt = prompt + f"<lora:光影:{style2}>"
|
prompt = prompt + f"<lora:光影:{style2}>"
|
||||||
if remove != None:
|
if remove != None:
|
||||||
negative_prompt = default_negative + "," + remove
|
negative_prompt = default_negative + "," + remove
|
||||||
|
if man == True:
|
||||||
|
prompt = prompt + "," + "(1boy)"
|
||||||
|
negative_prompt = negative_prompt + "," + "(1girl)"
|
||||||
add_prompt = random.choice([True, False])
|
add_prompt = random.choice([True, False])
|
||||||
if add_prompt == True:
|
if add_prompt == True:
|
||||||
qprompt = prompt + f"<lora:canistermix1.1:{random.choice(weight)}>"
|
qprompt = prompt + f"<lora:canistermix1.1:{random.choice(weight)}>"
|
||||||
@ -84,3 +92,37 @@ async def image_to_base64(image) -> str:
|
|||||||
async def base64_to_image(base642) -> File:
|
async def base64_to_image(base642) -> File:
|
||||||
attachment = File(io.BytesIO(base64.b64decode(base642)), filename="image.png")
|
attachment = File(io.BytesIO(base64.b64decode(base642)), filename="image.png")
|
||||||
return attachment
|
return attachment
|
||||||
|
|
||||||
|
async def get_gpuserver_status(url) -> Dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
# latency도 측정
|
||||||
|
|
||||||
|
async with session.get("http://172.30.1.49:7860/sdapi/v1/memory", timeout=10) as response:
|
||||||
|
|
||||||
|
if response.status == 200:
|
||||||
|
# latency 측정
|
||||||
|
#latency = response.headers["X-Response-Time"]
|
||||||
|
result = await response.json()
|
||||||
|
memstatus = result["ram"]["used"]
|
||||||
|
cudamemstatus = result["cuda"]["system"]["used"]
|
||||||
|
oomcount = result["cuda"]["events"]["oom"]
|
||||||
|
return {"status": "online", "system_memory_usage": bytes_to_gb(memstatus), "cuda_memory_usage": bytes_to_gb(cudamemstatus), "oom_count": oomcount}
|
||||||
|
except Exception as e:
|
||||||
|
return {"status": "offline"}
|
||||||
|
|
||||||
|
def bytes_to_gb(bytes: int) -> float:
|
||||||
|
return round(bytes / 1024 / 1024 / 1024, 2)
|
||||||
|
|
||||||
|
|
||||||
|
async def Get_Backend_latency():
|
||||||
|
start_time = time.time()
|
||||||
|
async with aiohttp.ClientSession() as client:
|
||||||
|
try:
|
||||||
|
async with client.get("http://172.30.1.49:7860/sdapi/v1/memory", timeout=10) as response:
|
||||||
|
if response.status_code == 200 or response.status_code == 404:
|
||||||
|
return round(time.time() - start_time, 2)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
return None
|
110
Christmas/Tagging/__init__.py
Normal file
110
Christmas/Tagging/__init__.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
from Christmas.Cogs.Event import model
|
||||||
|
|
||||||
|
|
||||||
|
import typing
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import PIL
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as rt
|
||||||
|
import pandas as pd
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Tagging:
|
||||||
|
def __init__(self, model, ProviderOptions: typing.Optional[str] = "CPUExecutionProvider"):
|
||||||
|
self.model_path = "Christmas/Tagging/model.onnx"
|
||||||
|
self.tag_path = "Christmas/Tagging/tags.csv"
|
||||||
|
self.model = model
|
||||||
|
self.general_threshold: typing.Optional[float] = 0.35
|
||||||
|
self.character_threshold: typing.Optional[float] = 0.85
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def make_square(self, img, target_size):
|
||||||
|
old_size = img.shape[:2]
|
||||||
|
desired_size = max(old_size)
|
||||||
|
desired_size = max(desired_size, target_size)
|
||||||
|
|
||||||
|
delta_w = desired_size - old_size[1]
|
||||||
|
delta_h = desired_size - old_size[0]
|
||||||
|
top, bottom = delta_h // 2, delta_h - (delta_h // 2)
|
||||||
|
left, right = delta_w // 2, delta_w - (delta_w // 2)
|
||||||
|
|
||||||
|
color = [255, 255, 255]
|
||||||
|
new_im = cv2.copyMakeBorder(
|
||||||
|
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
|
||||||
|
)
|
||||||
|
return new_im
|
||||||
|
|
||||||
|
|
||||||
|
def smart_resize(self, img, size):
|
||||||
|
# Assumes the image has already gone through make_square
|
||||||
|
if img.shape[0] > size:
|
||||||
|
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
|
||||||
|
elif img.shape[0] < size:
|
||||||
|
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def load_labels(self, path):
|
||||||
|
df = pd.read_csv(path)
|
||||||
|
tag_names = df["name"].tolist()
|
||||||
|
category = df["category"].tolist()
|
||||||
|
rating_indexes = [i for i, cat in enumerate(category) if cat == 9]
|
||||||
|
general_indexes = [i for i, cat in enumerate(category) if cat == 0]
|
||||||
|
character_indexes = [i for i, cat in enumerate(category) if cat == 4]
|
||||||
|
return tag_names, rating_indexes, general_indexes, character_indexes
|
||||||
|
|
||||||
|
def preprocess_image(self, image):
|
||||||
|
image = image.convert("RGBA")
|
||||||
|
new_image = Image.new("RGBA", image.size, "WHITE")
|
||||||
|
new_image.paste(image, mask=image)
|
||||||
|
image = new_image.convert("RGB")
|
||||||
|
image = np.asarray(image)[:, :, ::-1]
|
||||||
|
return image
|
||||||
|
|
||||||
|
def _predict_image(self, image, general_threshold, character_threshold):
|
||||||
|
"""discord.File의 byteio를 받아서 이미지를 예측합니다."""
|
||||||
|
image = Image.open(io.BytesIO(image))
|
||||||
|
tag_names, rating_indexes, general_indexes, character_indexes = self.load_labels(self.tag_path)
|
||||||
|
model = self.model
|
||||||
|
|
||||||
|
_, height, width, _ = model.get_inputs()[0].shape
|
||||||
|
|
||||||
|
image = self.preprocess_image(image)
|
||||||
|
image = self.make_square(image, height)
|
||||||
|
image = self.smart_resize(image, height)
|
||||||
|
image = image.astype(np.float32)
|
||||||
|
image = np.expand_dims(image, 0)
|
||||||
|
|
||||||
|
input_name = model.get_inputs()[0].name
|
||||||
|
label_name = model.get_outputs()[0].name
|
||||||
|
probs = model.run([label_name], {input_name: image})[0]
|
||||||
|
|
||||||
|
labels = list(zip(tag_names, probs[0].astype(float)))
|
||||||
|
|
||||||
|
# Extract ratings, general, and character labels
|
||||||
|
ratings_names = [labels[i] for i in rating_indexes]
|
||||||
|
general_names = [labels[i] for i in general_indexes]
|
||||||
|
character_names = [labels[i] for i in character_indexes]
|
||||||
|
|
||||||
|
rating = dict(ratings_names)
|
||||||
|
general_res = {name: prob for name, prob in general_names if prob > general_threshold}
|
||||||
|
character_res = {name: prob for name, prob in character_names if prob > character_threshold}
|
||||||
|
|
||||||
|
general_res_sorted = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True))
|
||||||
|
a = ", ".join(general_res_sorted.keys()).replace("_", " ").replace("(", "\(").replace(")", "\)")
|
||||||
|
c = ", ".join(general_res_sorted.keys())
|
||||||
|
|
||||||
|
return a, c, rating, character_res, general_res_sorted
|
||||||
|
|
||||||
|
|
||||||
|
async def predict(self, image, general_threshold: typing.Optional[float] = None, character_threshold: typing.Optional[float] = None):
|
||||||
|
if general_threshold is None:
|
||||||
|
general_threshold = self.general_threshold
|
||||||
|
if character_threshold is None:
|
||||||
|
character_threshold = self.character_threshold
|
||||||
|
return await asyncio.to_thread(self._predict_image, image, general_threshold, character_threshold)
|
BIN
Christmas/Tagging/model.onnx
Normal file
BIN
Christmas/Tagging/model.onnx
Normal file
Binary file not shown.
9084
Christmas/Tagging/tags.csv
Normal file
9084
Christmas/Tagging/tags.csv
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,9 +1,15 @@
|
|||||||
|
import wavelink
|
||||||
|
import typing
|
||||||
|
import psutil
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from discord import Embed, Colour, Embed, Member
|
from discord import Embed, Colour, Embed, Member
|
||||||
from discord.types.embed import EmbedType
|
from discord.types.embed import EmbedType
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from korcen import korcen
|
from korcen import korcen
|
||||||
|
|
||||||
|
from Christmas.Module import get_gpuserver_status
|
||||||
|
start_time = datetime.now()
|
||||||
|
|
||||||
class ChristmasEmbed(Embed):
|
class ChristmasEmbed(Embed):
|
||||||
def __init__(self, *,
|
def __init__(self, *,
|
||||||
color: int | Colour | None = 0xf4f9ff,
|
color: int | Colour | None = 0xf4f9ff,
|
||||||
@ -51,6 +57,45 @@ class Default_Embed:
|
|||||||
embed.set_footer()
|
embed.set_footer()
|
||||||
return embed
|
return embed
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def already_register():
|
||||||
|
embed = ChristmasEmbed(title="❌ 가입 실패!", description="이미 가입된 서버에요!")
|
||||||
|
embed.add_field(name="안내", value="이미 가입된 서버에요!", inline=False)
|
||||||
|
embed.set_footer()
|
||||||
|
return embed
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cooldown(sec):
|
||||||
|
return "이 명령어는 " + str(sec) + "초 뒤에 다시 사용할 수 있어요!"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def BotInfo(gpuserver: typing.Dict[str, typing.Any], bot, APIlatency) -> Embed:
|
||||||
|
current_time = datetime.now()
|
||||||
|
uptime = current_time - start_time
|
||||||
|
uptime = str(uptime).split(".")[0]
|
||||||
|
"""
|
||||||
|
gpuserver: [system_memory_usage,cuda_memory_usage,oom_count]
|
||||||
|
"""
|
||||||
|
embed = Embed(title="**봇 정보**", description="크돌이의 정보에요!")
|
||||||
|
embed.add_field(name="**봇 개요**", value=f"봇 ID: {bot.user.id}\n봇 버전: 0.0.9\n가동시간: {str(uptime)}", inline=False)
|
||||||
|
orin = psutil.virtual_memory().used
|
||||||
|
orin = orin / 1024 / 1024 / 1024
|
||||||
|
if gpuserver == None or gpuserver["status"] == "offline":
|
||||||
|
embed.add_field(name="**봇 자원**", value=f"GPU서버1 메모리 사용량: **오류**\nGPU서버1 CUDA 메모리 사용량: **오류**\nGPU서버1 메모리 오류 횟수: **오류**\n 현재 샤드 메모리 사용량:{round(orin)}GB", inline=False)
|
||||||
|
else:
|
||||||
|
mem_usage = gpuserver["system_memory_usage"]
|
||||||
|
cuda_memory_usage = gpuserver["cuda_memory_usage"]
|
||||||
|
oom_count = gpuserver["oom_count"]
|
||||||
|
embed.add_field(name="**봇 자원**", value=f"현재 샤드 메모리 사용량: {round(orin)}GB\n\nGPU서버1 메모리 사용량: {mem_usage}GB/128GB\nGPU서버1 GPU 메모리 사용량: {cuda_memory_usage}GB/80GB\nGPU서버1 메모리 오류 횟수: {oom_count}", inline=False)
|
||||||
|
embed.add_field(name="**봇 통계**", value=f"🏘️ **{len(bot.guilds)}**개의 서버에서 봇을 사용중이에요!\n🤖 **{len(bot.users)}**명의 유저와 함께하는 중이에요!", inline=False)
|
||||||
|
if APIlatency is None:
|
||||||
|
embed.add_field(name="**봇 핑**", value=f"🏓 **디스코드 게이트웨이 핑**: {round(bot.latency * 1000)}ms\n🏓 **AI 게이트웨이 핑**: 오류", inline=False)
|
||||||
|
else:
|
||||||
|
embed.add_field(name="**봇 핑**", value=f"🏓 **디스코드 게이트웨이 핑**: {round(bot.latency * 1000)}ms\n🏓 **AI 게이트웨이 핑**: {APIlatency}ms", inline=False)
|
||||||
|
embed.set_footer()
|
||||||
|
return embed
|
||||||
|
|
||||||
class Mail_Embed:
|
class Mail_Embed:
|
||||||
|
|
||||||
@ -77,7 +122,7 @@ class Mail_Embed:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def mail_cant_edit():
|
def mail_cant_edit():
|
||||||
embed = ChristmasEmbed(title="❌ 편지 수정 실패!", description="편지 수정에 실패했어요!")
|
embed = ChristmasEmbed(title="❌ 편지 수정 실패!", description="편지 수정에 실패했어요!")
|
||||||
embed.add_field(name="안내", value="편지는 한 번 전송하면 한번의 수정 기회 이후에는 취소할 수 없어요!", inline=False)
|
embed.add_field(name="안내", value="편지는 한 번 전송하면 한번의 수정 기회 이후에는 취소하거나 수정 할 수 없어요!", inline=False)
|
||||||
embed.set_footer()
|
embed.set_footer()
|
||||||
return embed
|
return embed
|
||||||
|
|
||||||
@ -144,9 +189,57 @@ class Aiart_Embed:
|
|||||||
embed.set_image(url="attachment://image.png")
|
embed.set_image(url="attachment://image.png")
|
||||||
embed.set_footer()
|
embed.set_footer()
|
||||||
return embed
|
return embed
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generating() -> Embed:
|
def generating() -> Embed:
|
||||||
embed = ChristmasEmbed(title="그림 생성중...", description="그림을 생성하는 중이에요!")
|
embed = ChristmasEmbed(title="그림 생성중...", description="그림을 생성하는 중이에요!")
|
||||||
embed.add_field(name="안내", value="그림 생성에는 최대 2분이 소요될 수 있어요!(크돌이는 돈이 없어요...)", inline=False)
|
embed.add_field(name="안내", value="그림 생성에는 최대 2분이 소요될 수 있어요!(크돌이는 돈이 없어요...)", inline=False)
|
||||||
embed.set_footer()
|
embed.set_footer()
|
||||||
return embed
|
return embed
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def evalate(tags, rating) -> ChristmasEmbed:
|
||||||
|
if tags == None:
|
||||||
|
embed = ChristmasEmbed(name="**오류**", value="그림 분석에 실패했어요. 나중에 다시 시도해주세요", inline=False)
|
||||||
|
return embed
|
||||||
|
else:
|
||||||
|
embed = ChristmasEmbed(title="**그림 분석 완료**", description="그림 분석이 완료되었어요!")
|
||||||
|
embed.set_image(url="attachment://image.png")
|
||||||
|
embed.add_field(name="**등급**", value=f"{rating}", inline=False)
|
||||||
|
texts = ""
|
||||||
|
for tag, score in tags:
|
||||||
|
percentage = score * 100
|
||||||
|
texts = texts + f"{tag} : {percentage:.1f}%\n"
|
||||||
|
embed.add_field(name="**분석된 요소**", value=f"{texts}", inline=False)
|
||||||
|
embed.set_footer()
|
||||||
|
return embed
|
||||||
|
|
||||||
|
class Music_Embed:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def music_notenable():
|
||||||
|
embed = ChristmasEmbed(title="❌ 음악 재생 실패!", description="음악 재생에 실패했어요!")
|
||||||
|
embed.add_field(name="안내", value="이 서버에서는 음악을 재생할수 없어요! \n만약 서버의 관리자라면 ``/설정`` 명령어로 음악 기능을 다시 활성화사킬수 있어요!", inline=False)
|
||||||
|
embed.set_footer()
|
||||||
|
return embed
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def author_not_voice():
|
||||||
|
embed = ChristmasEmbed(title="❌ 음악 재생 실패!", description="음악 재생에 실패했어요!")
|
||||||
|
embed.add_field(name="안내", value="음악을 재생하려면 음성채널에 들어가야 해요!", inline=False)
|
||||||
|
embed.set_footer()
|
||||||
|
return embed
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def author_diffrent_voice():
|
||||||
|
embed = ChristmasEmbed(title="❌ 음악 재생 실패!", description="음악 재생에 실패했어요!")
|
||||||
|
embed.add_field(name="안내", value="이미 다른 채널에 접속되어있어요!", inline=False)
|
||||||
|
embed.set_footer()
|
||||||
|
return embed
|
||||||
|
|
||||||
|
#@staticmethod
|
||||||
|
#def music_queue(music: wavelink.Playable):
|
||||||
|
#음악을 queue에 넣음
|
||||||
|
#embed = ChristmasEmbed(title="✅ 음악 재생 성공!", description="음악 재생에 성공했어요!")
|
||||||
|
|
||||||
|
#embed =
|
@ -6,6 +6,7 @@ from discord.ui import Modal, InputText
|
|||||||
from Christmas.UI.Embed import Mail_Embed, Aiart_Embed
|
from Christmas.UI.Embed import Mail_Embed, Aiart_Embed
|
||||||
from Christmas.UI.Buttons import Mail_Confirm_Button
|
from Christmas.UI.Buttons import Mail_Confirm_Button
|
||||||
from Christmas.Module import process_prompt, post_gpu_server, base64_to_image
|
from Christmas.Module import process_prompt, post_gpu_server, base64_to_image
|
||||||
|
from Christmas.config import ChristmasConfig
|
||||||
|
|
||||||
BLOCKTAG = [
|
BLOCKTAG = [
|
||||||
"nsfw",
|
"nsfw",
|
||||||
@ -40,6 +41,7 @@ class Send_Mail_Modal(Modal):
|
|||||||
def __init__(self, reciveuser: Member, editmode: bool, *args, **kwargs):
|
def __init__(self, reciveuser: Member, editmode: bool, *args, **kwargs):
|
||||||
self.reciveuser = reciveuser
|
self.reciveuser = reciveuser
|
||||||
self.editmode = editmode
|
self.editmode = editmode
|
||||||
|
|
||||||
super().__init__(timeout=None, *args, **kwargs)
|
super().__init__(timeout=None, *args, **kwargs)
|
||||||
|
|
||||||
self.add_item(InputText(label="제목", placeholder="제목을 입력해주세요.", style=InputTextStyle.short, required=True, custom_id="mail_title"))
|
self.add_item(InputText(label="제목", placeholder="제목을 입력해주세요.", style=InputTextStyle.short, required=True, custom_id="mail_title"))
|
||||||
@ -58,6 +60,7 @@ class Aiart(Modal):
|
|||||||
self.res = res
|
self.res = res
|
||||||
self.style1 = style1
|
self.style1 = style1
|
||||||
self.style2 = style2
|
self.style2 = style2
|
||||||
|
self.config = ChristmasConfig()
|
||||||
self.afterprocess = afterprocess
|
self.afterprocess = afterprocess
|
||||||
super().__init__(timeout=None, *args, **kwargs)
|
super().__init__(timeout=None, *args, **kwargs)
|
||||||
|
|
||||||
@ -81,7 +84,7 @@ class Aiart(Modal):
|
|||||||
await interaction.response.send_message(embed=Aiart_Embed.generating(), ephemeral=self.show)
|
await interaction.response.send_message(embed=Aiart_Embed.generating(), ephemeral=self.show)
|
||||||
#prompt: str, remove: str, res: list, isnsfw: bool, style1: float, style2: float, afterprocess: float
|
#prompt: str, remove: str, res: list, isnsfw: bool, style1: float, style2: float, afterprocess: float
|
||||||
payload = await process_prompt(prompt, remove, self.res, self.allownsfw, self.style1, self.style2, self.afterprocess)
|
payload = await process_prompt(prompt, remove, self.res, self.allownsfw, self.style1, self.style2, self.afterprocess)
|
||||||
result = await post_gpu_server("http://172.30.1.49:7860/sdapi/v1/txt2img", payload)
|
result = await post_gpu_server(f"{ChristmasConfig.AI()}/sdapi/v1/txt2img", payload)
|
||||||
if result["status"] != True:
|
if result["status"] != True:
|
||||||
return await interaction.edit_original_response(embed=Aiart_Embed.failed_generate())
|
return await interaction.edit_original_response(embed=Aiart_Embed.failed_generate())
|
||||||
else:
|
else:
|
||||||
|
@ -32,3 +32,11 @@ class ChristmasConfig:
|
|||||||
@property
|
@property
|
||||||
def DATABASE(self):
|
def DATABASE(self):
|
||||||
return self.json["DATABASE"]
|
return self.json["DATABASE"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def LAVALINK(self):
|
||||||
|
return self.json["LAVALINKS"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def AI(self):
|
||||||
|
return self.json["AI_GATEWAY"]
|
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import mafic
|
||||||
|
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
@ -21,8 +22,6 @@ class Christmas(AutoShardedBot):
|
|||||||
|
|
||||||
def load_cogs(bot) -> None:
|
def load_cogs(bot) -> None:
|
||||||
for filename in os.listdir("Christmas/Cogs"):
|
for filename in os.listdir("Christmas/Cogs"):
|
||||||
if filename == "__pycache__":
|
|
||||||
continue
|
|
||||||
if filename.endswith(".py"):
|
if filename.endswith(".py"):
|
||||||
bot.load_extension(f"Christmas.Cogs.{filename[:-3]}")
|
bot.load_extension(f"Christmas.Cogs.{filename[:-3]}")
|
||||||
|
|
||||||
|
@ -5,3 +5,5 @@ korcen
|
|||||||
nanoid
|
nanoid
|
||||||
pendulum
|
pendulum
|
||||||
Wavelink
|
Wavelink
|
||||||
|
onnxruntime
|
||||||
|
psutil
|
Loading…
Reference in New Issue
Block a user