Fix ctransformers model unload (#3711)

Add missing comma in model types list

Fixes marella/ctransformers#111
This commit is contained in:
Ravindra Marella 2023-08-27 19:23:48 +05:30 committed by GitHub
parent 0c9e818bb8
commit e4c3e1bdd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 5 deletions

View File

@ -10,8 +10,8 @@ class CtransformersModel:
pass pass
@classmethod @classmethod
def from_pretrained(self, path): def from_pretrained(cls, path):
result = self() result = cls()
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
str(path), str(path),
@ -24,13 +24,13 @@ class CtransformersModel:
mlock=shared.args.mlock mlock=shared.args.mlock
) )
self.model = AutoModelForCausalLM.from_pretrained( result.model = AutoModelForCausalLM.from_pretrained(
str(result.model_dir(path) if result.model_type_is_auto() else path), str(result.model_dir(path) if result.model_type_is_auto() else path),
model_type=(None if result.model_type_is_auto() else shared.args.model_type), model_type=(None if result.model_type_is_auto() else shared.args.model_type),
config=config config=config
) )
logger.info(f'Using ctransformers model_type: {self.model.model_type} for {self.model.model_path}') logger.info(f'Using ctransformers model_type: {result.model.model_type} for {result.model.model_path}')
return result, result return result, result
def model_type_is_auto(self): def model_type_is_auto(self):

View File

@ -304,7 +304,7 @@ loaders_model_types = {
"gptneox", "gptneox",
"llama", "llama",
"mpt", "mpt",
"dollyv2" "dollyv2",
"replit", "replit",
"starcoder", "starcoder",
"gptbigcode", "gptbigcode",