Instructions to use ytgui/Search-R3.0-Small with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ytgui/Search-R3.0-Small with Transformers:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("ytgui/Search-R3.0-Small")
model = AutoModelForCausalLM.from_pretrained("ytgui/Search-R3.0-Small")
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| from torch import nn | |
| from tqdm import tqdm | |
| from torch.nn import functional as F | |
| from transformers import ( | |
| set_seed, pipeline, AutoTokenizer, AutoModelForCausalLM | |
| ) | |
# System prompt used by SearchR3.format when a text is embedded directly
# (the assistant turn is a fixed reply ending in <|embed_token|>).
EMBEDDING = (
    "\n"
    "You are a helpful AI assistant. Your task is to analyze input text and "
    "create a high-quality semantic vector embedding, which represents key "
    "concepts, relationships, and semantic meaning.\n"
)

# System prompt used by SearchR3.generate: asks the model to enrich the user
# input and to finish its reply with the <|embed_token|> marker.
GENERATION = (
    "\n"
    "You are a helpful AI assistant. Your task is to enrich user input for "
    "more effective embedding representation by adding semantic depth.\n"
    "For each input, briefly enhance the content by:\n"
    "1. Identifying core concepts and their relationships.\n"
    "2. Including key terminology with essential definitions.\n"
    "3. Adding contextually relevant synonyms and related terms.\n"
    "4. Connecting to related topics and common applications without "
    "excessive elaboration.\n"
    "To represent the final embedding, you MUST end every response with "
    "<|embed_token|>.\n"
)
class SearchR3(nn.Module):
    """Wrap a causal LM for reasoning-augmented generation and embedding.

    ``generate`` asks the model to enrich each input with the GENERATION
    system prompt; the reply is expected to end with ``<|embed_token|>``.
    ``encode`` runs a forward pass over a chat transcript. It returns the
    L2-normalised final-layer hidden state at the ``<|embed_token|>``
    position, one row per input.

    Args:
        path: Model name or local path handed to ``from_pretrained``.
        max_length: Token budget. ``encode`` truncates inputs to this many
            tokens. ``generate`` truncates prompts to half of it and uses the
            remainder for new tokens.
        batch_size: Inputs longer than this are processed in chunks of this
            size.

    Raises:
        ValueError: if ``<|embed_token|>`` does not tokenize to exactly one id.
    """

    def __init__(self,
                 path: str,
                 max_length: int,
                 batch_size: int):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype='auto', device_map='auto'
        )
        # Left padding and left truncation keep the end of every sequence
        # aligned at the right edge. Both the prompt slicing in generate()
        # and the trailing-window pooling in encode() rely on that.
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, truncation_side='left', padding_side='left'
        )
        # Encode without special tokens, so a tokenizer that prepends BOS
        # cannot make index 0 pick the wrong id. The marker must be exactly
        # one token for the pooling mask to work.
        embed_ids = self.tokenizer.encode(
            '<|embed_token|>', add_special_tokens=False
        )
        if len(embed_ids) != 1:
            raise ValueError('<|embed_token|> must encode to a single token')
        self.embed_token = embed_ids[0]
        self.max_length = max_length
        self.batch_size = batch_size

    @property
    def device(self):
        """Device of the model's first parameter; inputs are moved here."""
        return next(self.model.parameters()).device

    def generate(self, batch: list[str]):
        """Enrich each input text with model-generated reasoning.

        Args:
            batch: list or tuple of input strings.

        Returns:
            One transcript per input: ``[system, user, assistant]`` message
            dicts. The assistant content is the decoded continuation with
            every special token except ``<|embed_token|>`` removed and
            surrounding whitespace stripped. An empty batch returns ``[]``.

        Raises:
            ValueError: if ``batch`` is not a list/tuple of strings.
        """
        if not isinstance(batch, (list, tuple)):
            raise ValueError('batch type is incorrect')
        if any(not isinstance(v, str) for v in batch):
            raise ValueError('batch item type is incorrect')
        if not batch:
            return []
        # batch: split oversized inputs into chunks of batch_size
        if len(batch) > self.batch_size:
            outputs = []
            for i in tqdm(range(0, len(batch), self.batch_size)):
                outputs.extend(self.generate(batch[i:i + self.batch_size]))
            return outputs
        # tokenize
        messages = [
            [
                {'role': 'system', 'content': GENERATION.strip()},
                {'role': 'user', 'content': item}
            ]
            for item in batch
        ]
        context = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The prompt gets at most half of the budget so that generation
        # always keeps at least the other half.
        inputs = self.tokenizer(
            context, padding='longest', truncation=True,
            return_tensors='pt', max_length=self.max_length // 2
        )
        prompt_length = inputs['input_ids'].size(-1)
        # generate
        self.model.eval()
        outputs = self.model.generate(
            **inputs.to(device=self.device),
            max_new_tokens=self.max_length - prompt_length
        )
        # With left padding every prompt ends at column prompt_length, so
        # slicing there keeps only the newly generated tokens.
        outputs = self.tokenizer.batch_decode(
            outputs[:, prompt_length:], skip_special_tokens=False
        )
        # cleanup: remove all special tokens except the embed marker
        for special_token in self.tokenizer.all_special_tokens:
            if special_token == '<|embed_token|>':
                continue
            outputs = [item.replace(special_token, '') for item in outputs]
        return [
            item + [{'role': 'assistant', 'content': text.strip()}]
            for item, text in zip(messages, outputs)
        ]

    def format(self, batch: list[str]):
        """Wrap plain texts in the fixed embedding transcript.

        Each text becomes ``[system(EMBEDDING), user(text),
        assistant('The embedding is: <|embed_token|>')]``.

        Raises:
            RuntimeError: if any item is not a string.
        """
        if any(not isinstance(v, str) for v in batch):
            raise RuntimeError('batch type is incorrect')
        return [
            [
                {'role': 'system', 'content': EMBEDDING.strip()},
                {'role': 'user', 'content': item},
                {'role': 'assistant', 'content': 'The embedding is: <|embed_token|>'}
            ]
            for item in batch
        ]

    def encode(self, batch: list):
        """Embed a batch of texts and/or chat transcripts.

        Args:
            batch: list or tuple. Each item is either a plain string (wrapped
                via ``format``) or a transcript: a list of ``{'role',
                'content'}`` dicts whose last two turns are user then
                assistant, e.g. the output of ``generate``. A transcript
                whose assistant turn lacks ``<|embed_token|>`` falls back to
                the plain embedding transcript for its user turn.

        Returns:
            float32 CPU tensor with one L2-normalised row per item, in input
            order. An empty batch gives shape ``(0, hidden_size)``.

        Raises:
            ValueError: if ``batch`` is not a list or tuple.
            RuntimeError: if a transcript does not end with user/assistant
                turns, or a tokenized sequence does not contain exactly one
                embed token in its trailing window.
        """
        if not isinstance(batch, (list, tuple)):
            raise ValueError('batch type is incorrect')
        if not batch:
            return torch.empty(0, self.model.config.hidden_size)
        # batch: split oversized inputs into chunks of batch_size
        if len(batch) > self.batch_size:
            chunks = [
                self.encode(batch[i:i + self.batch_size])
                for i in tqdm(range(0, len(batch), self.batch_size))
            ]
            return torch.cat(chunks, dim=0)
        # format: wrap strings item by item so mixed batches also work
        batch = [
            self.format([m])[0] if isinstance(m, str) else m
            for m in batch
        ]
        # validate
        if any(m[-1]['role'] != 'assistant' for m in batch):
            raise RuntimeError('unexpected role')
        if any(m[-2]['role'] != 'user' for m in batch):
            raise RuntimeError('unexpected role')
        # ensure <embed_token>: fall back to the plain embedding transcript
        # (which always contains the marker) when the assistant turn lacks it
        batch = [
            m if '<|embed_token|>' in m[-1]['content']
            else self.format([m[-2]['content']])[0]
            for m in batch
        ]
        # tokenize
        context = self.tokenizer.apply_chat_template(
            batch, tokenize=False, add_generation_prompt=False
        )
        inputs = self.tokenizer(
            context, padding='longest', truncation=True,
            return_tensors='pt', max_length=self.max_length
        )
        # forward
        # NOTE(review): runs with autograd enabled (no torch.no_grad()), so
        # the returned tensor carries a graph. Wrap the call if encode() is
        # only ever used for inference.
        self.model.eval()
        outputs = self.model(
            **inputs.to(device=self.device),
            return_dict=True, output_hidden_states=True
        )
        hidden_state = outputs['hidden_states'][-1]
        # pooling: only the last four positions are searched. With left
        # padding the marker sits near the right edge, presumably followed by
        # a few end-of-turn tokens from the chat template (confirm against
        # the template).
        input_ids = inputs['input_ids']
        length = input_ids.size(-1)
        window = torch.arange(length, device=input_ids.device) > length - 5
        embed_mask = (input_ids == self.embed_token) & window.unsqueeze(0)
        # Exactly one marker per row keeps output rows aligned with inputs;
        # otherwise rows would be silently dropped or duplicated.
        if not bool((embed_mask.sum(dim=-1) == 1).all()):
            raise RuntimeError('unexpected embed token')
        # Under device_map sharding the last hidden state can live on a
        # different device than the inputs.
        pooled = hidden_state[embed_mask.to(hidden_state.device)]
        return F.normalize(pooled.cpu().float(), dim=-1)
def main():
    """Run a small end-to-end demo of Search-R3.

    1. Plain chat through a ``text-generation`` pipeline.
    2. Reasoning-augmented enrichment of a query via ``SearchR3.generate``.
    3. Euclidean distances between the enriched query embedding and two
       document embeddings.
    """
    from pprint import pprint

    set_seed(42)
    model_id = 'ytgui/Search-R3.0-Small'

    # 1. basic chat
    chat = pipeline(
        task='text-generation', model=model_id,
        torch_dtype='auto', device_map='auto',
    )
    conversation = [{"role": 'user', 'content': 'Who are you?'}]
    pprint(chat(conversation, max_new_tokens=256))

    # 2. reasoning
    search = SearchR3(model_id, max_length=1024, batch_size=8)
    enriched = search.generate(
        batch=['what python library is useful for data analysis?']
    )
    pprint(enriched)

    # 3. embedding
    corpus = [
        'pandas is a fast, powerful, flexible and easy to use open source data analysis and manipulation tool, built on top of the Python programming language.',
        'The giant panda (Ailuropoda melanoleuca), also known as the panda bear or simply panda, is a bear species endemic to China. It is characterised by its white coat with black patches around the eyes, ears, legs and shoulders.',
    ]
    doc_vectors = search.encode(batch=corpus)
    query_vectors = search.encode(batch=enriched)
    print('distance:', torch.cdist(query_vectors, doc_vectors, p=2.0))


if __name__ == '__main__':
    main()